Implementation:Microsoft Onnxruntime CUDA LAMB Impl
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool declaring CUDA kernel function interfaces for the LAMB optimizer implementation in the ONNX Runtime CUDA training framework.
Description
Header file declaring the CUDA kernel function templates used by the LAMB optimizer. Two primary kernel launch functions are declared:
- LambComputeDirection computes the LAMB update direction from the weights, gradients, and first and second moments, with optional loss scaling and gradient norm clipping. Its parameters include alpha and beta (EMA coefficients for the moments), lambda (weight decay), epsilon (numerical stability term), max_norm (gradient clipping threshold), and the bias correction factors alpha_correction and beta_correction. It outputs the update direction and the updated moments.
- LambUpdate applies the trust ratio to the update direction and updates the weights. It takes the ratio of the weight L2 norm (w_norm) to the update L2 norm (r_norm), clamps it to [ratio_min, ratio_max], scales the update by the learning rate (eta) and the trust ratio, and optionally writes an updated mixed-precision copy of the weights.
The header also declares multi-tensor functors (LambMultiTensorComputeDirectionFunctor, LambMultiTensorUpdateFunctor, LambMultiTensorReductionFunctor) for batched processing of multiple parameter groups.
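The per-element arithmetic behind the update direction can be sketched on the CPU as follows. This is a minimal reference under stated assumptions, not the ONNX Runtime kernel: the names LambDirectionParams, CombinedGradScale, and LambDirectionReference are hypothetical, grad_norm is assumed to be the norm of the loss-scaled gradients, and the way loss scaling and clipping are folded into a single divisor is one plausible reading of the description above.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for illustration; not the ONNX Runtime kernel.
struct LambDirectionParams {
  float alpha;             // EMA coefficient for the first moment
  float beta;              // EMA coefficient for the second moment
  float lambda;            // weight decay
  float epsilon;           // numerical stability term
  float alpha_correction;  // bias correction factor for the first moment
  float beta_correction;   // bias correction factor for the second moment
};

// Assumption: loss scaling and norm clipping are folded into one divisor.
// A null pointer means the optional input is absent.
float CombinedGradScale(const float* loss_scale, const float* grad_norm,
                        float max_norm) {
  float scale = (loss_scale != nullptr) ? *loss_scale : 1.0f;
  if (grad_norm != nullptr) {
    const float unscaled_norm = *grad_norm / scale;
    if (unscaled_norm > max_norm) {
      scale *= unscaled_norm / max_norm;  // shrink gradients to max_norm
    }
  }
  return scale;
}

// Updates m1 and m2 in place and writes the update direction to d.
void LambDirectionReference(const std::vector<float>& w,
                            const std::vector<float>& g, float g_scale,
                            const LambDirectionParams& p,
                            std::vector<float>& m1, std::vector<float>& m2,
                            std::vector<float>& d) {
  d.resize(w.size());
  for (std::size_t i = 0; i < w.size(); ++i) {
    const float g_unscaled = g[i] / g_scale;
    // Exponential moving averages of the gradient and squared gradient.
    m1[i] = p.alpha * m1[i] + (1.0f - p.alpha) * g_unscaled;
    m2[i] = p.beta * m2[i] + (1.0f - p.beta) * g_unscaled * g_unscaled;
    // Bias-corrected moments.
    const float m1_hat = m1[i] / p.alpha_correction;
    const float m2_hat = m2[i] / p.beta_correction;
    // Adam-style step plus weight decay.
    d[i] = p.lambda * w[i] + m1_hat / (std::sqrt(m2_hat) + p.epsilon);
  }
}
```

A CPU reference like this is mainly useful for checking GPU results in a unit test; the real kernels are templated on T1, T2, T3, and T_GRAD_NORM so that weights, gradients, and moments can use different precisions.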
Usage
This header is included by the LAMB optimizer implementation (lamb.cc) to access the CUDA kernel launch functions. The kernel implementations themselves reside in a .cu file.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.h
- Lines: 1-166
Signature
template <typename T1, typename T2, typename T3, typename T_GRAD_NORM>
void LambComputeDirection(
    cudaStream_t stream, const T1* weights, const T2* grads,
    const T3* moment_1, const T3* moment_2, const T1* loss_scale,
    const T_GRAD_NORM* grad_norm, float alpha, float beta, float lambda,
    float epsilon, float max_norm, float alpha_correction, float beta_correction,
    T2* update_direction, T3* moment_1_out, T3* moment_2_out, size_t count);

template <typename T1, typename T2, typename T3, typename T_MIXED_PRECISION_FP>
void LambUpdate(
    cudaStream_t stream, const T1* eta, const float ratio_min, const float ratio_max,
    const T2* r_norm, const T2* w_norm, const T2* weights, const T3* update_direction,
    T2* weights_out, T3* update_direction_out,
    T_MIXED_PRECISION_FP* mixed_precision_weights_out, size_t count);

template <typename TIn1, typename TIn2, typename TOut1, typename TOut2, typename TBuf>
struct LambMultiTensorComputeDirectionFunctor;

template <typename TIn1, typename TIn2, typename TOut1, typename TOut2, typename TBuf>
struct LambMultiTensorUpdateFunctor;

template <typename TIn, typename TOut>
struct LambMultiTensorReductionFunctor;
Import
#include "orttraining/training_ops/cuda/optimizer/lamb_impl.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| weights | const T1* | Yes | Current parameter weights |
| grads | const T2* | Yes | Current gradients |
| moment_1 | const T3* | Yes | First moment (mean of gradients) |
| moment_2 | const T3* | Yes | Second moment (mean of squared gradients) |
| loss_scale | const T1* | No | Loss scale factor for mixed precision |
| grad_norm | const T_GRAD_NORM* | No | Global gradient norm for clipping |
Outputs
| Name | Type | Description |
|---|---|---|
| update_direction | T2* | Computed LAMB update direction (from LambComputeDirection) |
| moment_1_out | T3* | Updated first moment (from LambComputeDirection) |
| moment_2_out | T3* | Updated second moment (from LambComputeDirection) |
| weights_out | T2* | Updated weights (from LambUpdate) |
| mixed_precision_weights_out | T_MIXED_PRECISION_FP* | Updated mixed-precision weight copy (from LambUpdate, optional) |
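The step that produces weights_out can be illustrated with a small CPU sketch of the trust-ratio update. It assumes the ratio is w_norm / r_norm clamped to [ratio_min, ratio_max], and that the plain learning rate is used when either norm is zero; the zero-norm fallback is an assumption, and TensorL2Norm and LambUpdateReference are hypothetical names, not part of lamb_impl.h.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: L2 norm of one tensor, accumulated in double.
float TensorL2Norm(const std::vector<float>& x) {
  double sum = 0.0;
  for (float v : x) {
    sum += static_cast<double>(v) * static_cast<double>(v);
  }
  return static_cast<float>(std::sqrt(sum));
}

// Hypothetical CPU reference: applies the clamped trust ratio and the
// learning rate eta to the update direction d, updating w in place.
void LambUpdateReference(float eta, float ratio_min, float ratio_max,
                         const std::vector<float>& d, std::vector<float>& w) {
  const float w_norm = TensorL2Norm(w);
  const float r_norm = TensorL2Norm(d);
  float step = eta;
  // Assumption: skip the trust ratio when either norm is zero.
  if (w_norm != 0.0f && r_norm != 0.0f) {
    const float ratio = w_norm / r_norm;
    step *= std::max(ratio_min, std::min(ratio, ratio_max));
  }
  for (std::size_t i = 0; i < w.size(); ++i) {
    w[i] -= step * d[i];
  }
}
```

In the declared LambUpdate the two norms arrive precomputed as r_norm and w_norm, and mixed_precision_weights_out, when provided, receives a lower-precision copy of the updated weights; this sketch omits that copy.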
Usage Examples
// Called from LambOptimizer::ComputeInternal
LambComputeDirection<float, float, float, float>(
    stream, weights, grads, m1, m2, loss_scale, grad_norm,
    alpha, beta, lambda, epsilon, max_norm,
    alpha_correction, beta_correction,
    update_dir, m1_out, m2_out, count);
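
The call above receives alpha_correction and beta_correction as precomputed floats. One common way to derive them, assumed here rather than stated by the header, is the Adam-style factor 1 - coefficient^step, with a value of 1 when bias correction is disabled. BiasCorrectionFactor is a hypothetical helper name.

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical helper: Adam-style bias correction factor for a given step.
// Returns 1 when bias correction is disabled, so dividing by it is a no-op.
float BiasCorrectionFactor(float coefficient, std::int64_t step, bool enabled) {
  if (!enabled) {
    return 1.0f;
  }
  return 1.0f - std::pow(coefficient, static_cast<float>(step));
}
```

Under this assumption, the caller would pass BiasCorrectionFactor(alpha, step, true) as alpha_correction and BiasCorrectionFactor(beta, step, true) as beta_correction.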