

Implementation:Microsoft Onnxruntime CUDA LAMB Impl

From Leeroopedia


Knowledge Sources
Domains: Training, CUDA_Kernels
Last Updated: 2026-02-10 04:00 GMT

Overview

Concrete tool declaring CUDA kernel function interfaces for the LAMB optimizer implementation in the ONNX Runtime CUDA training framework.

Description

Header file declaring the CUDA kernel function templates used by the LAMB optimizer. It declares two primary kernel launch functions.

(1) LambComputeDirection computes the LAMB update direction from the weights, the gradients, and the first and second moments, with optional loss scaling and gradient-norm clipping. Its scalar parameters are alpha and beta (decay coefficients of the exponential moving averages), lambda (weight decay), epsilon (numerical-stability term), max_norm (gradient clipping threshold), and the bias-correction factors alpha_correction and beta_correction. It outputs the update direction and the updated moments.

(2) LambUpdate applies the trust ratio to the update direction and updates the weights. The trust ratio is the L2 norm of the weights divided by the L2 norm of the update, clamped to [ratio_min, ratio_max]; the update is then scaled by the learning rate (eta) and the trust ratio. It can optionally write an updated mixed-precision copy of the weights.

The header also declares multi-tensor functors (LambMultiTensorComputeDirectionFunctor, LambMultiTensorUpdateFunctor, and LambMultiTensorReductionFunctor) that batch this processing over multiple parameter groups.
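To make the LambComputeDirection description concrete, here is a minimal host-side reference in plain C++ for one pass over a single tensor. It is a sketch under stated assumptions, not the ONNX Runtime kernel: the name LambComputeDirectionRef and the DirectionResult struct are hypothetical; grad_norm is assumed to be the L2 norm of the unscaled gradients; and alpha_correction and beta_correction are assumed to divide the moments in the usual Adam style before the weight-decay term lambda * w is added.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical output bundle for the reference below.
struct DirectionResult {
  std::vector<float> direction;
  std::vector<float> m1_out;
  std::vector<float> m2_out;
};

// Hypothetical host-side reference for one LAMB "compute direction" pass.
// Assumptions (not taken from the ONNX Runtime source):
//  * loss_scale and grad_norm may be null, mirroring the optional inputs;
//  * grad_norm is the L2 norm of the *unscaled* gradients;
//  * alpha_correction / beta_correction divide the moments (Adam-style).
DirectionResult LambComputeDirectionRef(
    const std::vector<float>& w, const std::vector<float>& g,
    const std::vector<float>& m1, const std::vector<float>& m2,
    const float* loss_scale, const float* grad_norm,
    float alpha, float beta, float lambda, float epsilon, float max_norm,
    float alpha_correction, float beta_correction) {
  const std::size_t n = w.size();
  assert(g.size() == n && m1.size() == n && m2.size() == n);

  // One combined divisor that removes loss scaling and applies norm clipping.
  float scale = loss_scale ? *loss_scale : 1.0f;
  if (grad_norm && max_norm > 0.0f && *grad_norm > max_norm) {
    scale *= *grad_norm / max_norm;
  }

  DirectionResult r{std::vector<float>(n), std::vector<float>(n),
                    std::vector<float>(n)};
  for (std::size_t i = 0; i < n; ++i) {
    const float gi = g[i] / scale;
    // Exponential moving averages of the gradient and its square.
    r.m1_out[i] = alpha * m1[i] + (1.0f - alpha) * gi;
    r.m2_out[i] = beta * m2[i] + (1.0f - beta) * gi * gi;
    // Bias-corrected Adam-style ratio plus weight decay.
    const float m1_hat = r.m1_out[i] / alpha_correction;
    const float m2_hat = r.m2_out[i] / beta_correction;
    r.direction[i] = m1_hat / (std::sqrt(m2_hat) + epsilon) + lambda * w[i];
  }
  return r;
}
```

A quick sanity check: on the first step with zero moments, alpha_correction = 1 - alpha and beta_correction = 1 - beta, and epsilon = 0, the bias-corrected ratio collapses to sign(g), so the direction is just sign(g) + lambda * w.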

Usage

Included by the LAMB optimizer implementation (lamb.cc) to access the CUDA kernel launch functions. The kernel definitions themselves live in a separate .cu file.
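Because the header only declares the templates while their definitions live in a .cu file, that .cu file typically has to explicitly instantiate every type combination that lamb.cc calls; otherwise the linker finds no definition. The single-file sketch below imitates that split in plain C++ with a hypothetical ScaleInPlace template. It illustrates the general C++ pattern only and does not reproduce ONNX Runtime's actual kernels or any instantiation macros it may use.

```cpp
#include <cassert>
#include <cstddef>

// ---- "Header" part: declaration only, as a header like lamb_impl.h does.
template <typename T>
void ScaleInPlace(T* data, T factor, std::size_t count);  // hypothetical

// ---- "Source" part: the definition, as a .cu file would hold it.
// A real CUDA version would launch a kernel on a cudaStream_t instead of
// looping on the host.
template <typename T>
void ScaleInPlace(T* data, T factor, std::size_t count) {
  assert(data != nullptr || count == 0);
  for (std::size_t i = 0; i < count; ++i) data[i] *= factor;
}

// Explicit instantiations. In this single file they are not strictly
// required, but in a real header/.cu split they are what give other
// translation units (e.g. lamb.cc) a definition to link against.
template void ScaleInPlace<float>(float*, float, std::size_t);
template void ScaleInPlace<double>(double*, double, std::size_t);
```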

Code Reference

Source Location

Signature

template <typename T1, typename T2, typename T3, typename T_GRAD_NORM>
void LambComputeDirection(
    cudaStream_t stream, const T1* weights, const T2* grads,
    const T3* moment_1, const T3* moment_2, const T1* loss_scale,
    const T_GRAD_NORM* grad_norm, float alpha, float beta, float lambda,
    float epsilon, float max_norm, float alpha_correction, float beta_correction,
    T2* update_direction, T3* moment_1_out, T3* moment_2_out, size_t count);

template <typename T1, typename T2, typename T3, typename T_MIXED_PRECISION_FP>
void LambUpdate(
    cudaStream_t stream, const T1* eta, const float ratio_min, const float ratio_max,
    const T2* r_norm, const T2* w_norm, const T2* weights, const T3* update_direction,
    T2* weights_out, T3* update_direction_out,
    T_MIXED_PRECISION_FP* mixed_precision_weights_out, size_t count);

template <typename TIn1, typename TIn2, typename TOut1, typename TOut2, typename TBuf>
struct LambMultiTensorComputeDirectionFunctor;

template <typename TIn1, typename TIn2, typename TOut1, typename TOut2, typename TBuf>
struct LambMultiTensorUpdateFunctor;

template <typename TIn, typename TOut>
struct LambMultiTensorReductionFunctor;
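The LambUpdate declaration takes precomputed norms (r_norm, w_norm) plus ratio_min and ratio_max. A minimal host-side sketch of the trust-ratio step follows. Assumptions, not confirmed by the source: LambUpdateRef is a hypothetical name; w_norm and r_norm are already L2 norms rather than squared norms; the ratio falls back to 1 when either norm is zero; eta is a positive learning rate and the scaled direction is subtracted from the weights. The real function's update_direction_out output is omitted here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side reference for the LambUpdate trust-ratio step.
void LambUpdateRef(float eta, float ratio_min, float ratio_max,
                   float r_norm, float w_norm,
                   const std::vector<float>& weights,
                   const std::vector<float>& update_direction,
                   std::vector<float>& weights_out,
                   std::vector<float>* mixed_precision_weights_out) {
  assert(weights.size() == update_direction.size());
  // Trust ratio ||w|| / ||d||, clamped; assumed to be 1 if a norm is zero.
  float ratio = 1.0f;
  if (w_norm != 0.0f && r_norm != 0.0f) {
    ratio = std::max(ratio_min, std::min(ratio_max, w_norm / r_norm));
  }
  const float step = eta * ratio;
  weights_out.resize(weights.size());
  for (std::size_t i = 0; i < weights.size(); ++i) {
    weights_out[i] = weights[i] - step * update_direction[i];
  }
  // Optional mixed-precision copy; a GPU kernel would store a half type,
  // this sketch simply copies the float result.
  if (mixed_precision_weights_out) *mixed_precision_weights_out = weights_out;
}
```

Clamping the ratio keeps layers whose weights are tiny or huge relative to their update from receiving a vanishing or exploding effective learning rate.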

Import

#include "orttraining/training_ops/cuda/optimizer/lamb_impl.h"

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| weights | T1* | Yes | Current parameter weights |
| grads | T2* | Yes | Current gradients |
| moment_1 | T3* | Yes | First moment (exponential moving average of gradients) |
| moment_2 | T3* | Yes | Second moment (exponential moving average of squared gradients) |
| loss_scale | T1* | No | Loss scale factor for mixed precision |
| grad_norm | T_GRAD_NORM* | No | Global gradient norm used for clipping |

Outputs

| Name | Type | Description |
|---|---|---|
| update_direction | T2* | Computed LAMB update direction (from LambComputeDirection) |
| moment_1_out | T3* | Updated first moment |
| moment_2_out | T3* | Updated second moment |
| weights_out | T2* | Updated weights (from LambUpdate) |
| mixed_precision_weights_out | T_MIXED_PRECISION_FP* | Updated mixed-precision copy of the weights (from LambUpdate; optional) |
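LambUpdate needs one weight norm and one update norm per parameter tensor, and the multi-tensor functors batch such work across many tensors. The sketch below shows the general idea in plain C++: each tensor is split into fixed-size chunks (as a GPU launch would split work across thread blocks), each chunk produces a partial sum of squares, and partials are accumulated per tensor. MultiTensorL2Norms and chunk_size are hypothetical; this is not the layout or interface of LambMultiTensorReductionFunctor.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical batched per-tensor L2-norm reduction.
std::vector<float> MultiTensorL2Norms(
    const std::vector<std::vector<float>>& tensors, std::size_t chunk_size) {
  assert(chunk_size > 0);
  std::vector<double> sum_sq(tensors.size(), 0.0);
  for (std::size_t t = 0; t < tensors.size(); ++t) {
    const std::vector<float>& x = tensors[t];
    for (std::size_t begin = 0; begin < x.size(); begin += chunk_size) {
      const std::size_t end = std::min(begin + chunk_size, x.size());
      double partial = 0.0;  // one chunk's worth of work ("one block")
      for (std::size_t i = begin; i < end; ++i) {
        partial += static_cast<double>(x[i]) * x[i];
      }
      // A GPU version might combine partials with an atomic add or a
      // second reduction pass; here a plain sum suffices.
      sum_sq[t] += partial;
    }
  }
  std::vector<float> norms(tensors.size());
  for (std::size_t t = 0; t < tensors.size(); ++t) {
    norms[t] = static_cast<float>(std::sqrt(sum_sq[t]));
  }
  return norms;
}
```

Accumulating in double is a host-side convenience to limit rounding error; a device implementation would choose its accumulator type separately.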

Usage Examples

// Called from LambOptimizer::ComputeInternal
LambComputeDirection<float, float, float, float>(
    stream, weights, grads, m1, m2, loss_scale, grad_norm,
    alpha, beta, lambda, epsilon, max_norm,
    alpha_correction, beta_correction,
    update_dir, m1_out, m2_out, count);
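The call above passes alpha_correction and beta_correction as plain floats, so the caller must compute them on the host for each step. Assuming they are the standard Adam bias-correction terms 1 - alpha^t and 1 - beta^t for a 1-based step t, and that a hypothetical do_bias_correction flag disables correction by setting both to 1, the computation might look like this (LambBiasCorrections is likewise a hypothetical name):

```cpp
#include <cassert>
#include <cmath>
#include <utility>

// Hypothetical helper: Adam-style bias-correction factors for step >= 1.
// Returns {alpha_correction, beta_correction}. Both are 1 when correction
// is disabled, which leaves the moments undivided.
std::pair<float, float> LambBiasCorrections(float alpha, float beta,
                                            long long step,
                                            bool do_bias_correction) {
  if (!do_bias_correction) return {1.0f, 1.0f};
  assert(step >= 1);
  const double t = static_cast<double>(step);
  return {static_cast<float>(1.0 - std::pow(static_cast<double>(alpha), t)),
          static_cast<float>(1.0 - std::pow(static_cast<double>(beta), t))};
}
```

Computing the powers in double avoids accumulating float rounding error over long training runs; the factors approach 1 as the step count grows, so bias correction mainly affects early steps.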
