Implementation:Microsoft Onnxruntime CUDA LAMB
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CUDA implementation of the LAMB (Layer-wise Adaptive Moments optimizer for Batch training) optimizer operator in the ONNX Runtime training framework.
Description
Implements the LambOptimizer operator for CUDA, which performs LAMB optimization across multiple weight groups using batched multi-tensor kernel launches. The optimizer supports up to 1024 parameter groups and updates tensors in place through the alias mappings generated by GenerateLambExtraAliasMapping. For each group, the ComputeInternal method processes the weights, gradients, first moment, second moment, and optional mixed-precision weights. The computation proceeds in four steps:
1. Compute the update direction from the bias-corrected first and second moments, with optional loss scaling and gradient norm clipping.
2. Compute the per-group L2 norms of the weights and of the update direction via reduce_square_sum.
3. Scale the update by the LAMB trust ratio, clamped between ratio_min and ratio_max.
4. Update the weights and, when mixed-precision weights are present, copy the result to the mixed-precision format.
The implementation uses multi-tensor functors for efficient batched GPU computation and is registered for multiple mixed-precision type combinations, including float and MLFloat16.
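The four steps described above can be illustrated for a single parameter group on the CPU. The following is a minimal sketch of the standard LAMB formulation, not the ONNX Runtime kernel: the types LambGroup and LambConfig, the function LambStep, and the grad_scale field (standing in for loss scaling and norm clipping) are hypothetical names introduced here, and the real CUDA kernels may differ in details such as how the learning rate and clamping interact.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical container for one parameter group; not an ONNX Runtime type.
struct LambGroup {
  std::vector<float> w;   // weights
  std::vector<float> g;   // gradients
  std::vector<float> m1;  // first moment
  std::vector<float> m2;  // second moment
};

// Hypothetical hyperparameter bundle; field names are illustrative.
struct LambConfig {
  float eta = 1e-3f;        // learning rate
  float alpha = 0.9f;       // first-moment decay
  float beta = 0.999f;      // second-moment decay
  float lambda = 0.0f;      // weight decay
  float epsilon = 1e-6f;    // numerical stabilizer
  float ratio_min = 0.05f;  // lower clamp for the trust ratio
  float ratio_max = 5.0f;   // upper clamp for the trust ratio
  float grad_scale = 1.0f;  // assumed combined loss-scale / clipping divisor
};

// Sum of squares, accumulated in double to limit rounding error.
inline double SquareSum(const std::vector<float>& v) {
  double s = 0.0;
  for (float x : v) s += static_cast<double>(x) * x;
  return s;
}

// Applies one LAMB step to a single group and returns the clamped trust ratio.
inline float LambStep(LambGroup& p, const LambConfig& c, long long step) {
  assert(step >= 1);
  const std::size_t n = p.w.size();
  assert(p.g.size() == n && p.m1.size() == n && p.m2.size() == n);

  // Bias-correction denominators for the current step.
  const float bc1 = 1.0f - std::pow(c.alpha, static_cast<float>(step));
  const float bc2 = 1.0f - std::pow(c.beta, static_cast<float>(step));

  // Step 1: update the moments and compute the update direction d.
  std::vector<float> d(n);
  for (std::size_t i = 0; i < n; ++i) {
    const float g = p.g[i] / c.grad_scale;
    p.m1[i] = c.alpha * p.m1[i] + (1.0f - c.alpha) * g;
    p.m2[i] = c.beta * p.m2[i] + (1.0f - c.beta) * g * g;
    const float m1_hat = p.m1[i] / bc1;
    const float m2_hat = p.m2[i] / bc2;
    d[i] = c.lambda * p.w[i] + m1_hat / (std::sqrt(m2_hat) + c.epsilon);
  }

  // Step 2: per-group L2 norms of the weights and of the direction.
  const float w_norm = static_cast<float>(std::sqrt(SquareSum(p.w)));
  const float d_norm = static_cast<float>(std::sqrt(SquareSum(d)));

  // Step 3: trust ratio, falling back to 1 when either norm is zero, then clamped.
  float ratio = (w_norm > 0.0f && d_norm > 0.0f) ? w_norm / d_norm : 1.0f;
  ratio = std::max(c.ratio_min, std::min(c.ratio_max, ratio));

  // Step 4: in-place weight update.
  for (std::size_t i = 0; i < n; ++i) p.w[i] -= c.eta * ratio * d[i];
  return ratio;
}
```

In this sketch the trust ratio is computed once per group, which is why the norms in step 2 are reduced per group rather than over all parameters at once.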
Usage
Invoked during the optimizer step of distributed training when the LAMB optimizer is selected. LAMB is commonly used for large-batch training of BERT and similar models.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/optimizer/lamb.cc
- Lines: 1-726
Signature
std::vector<std::pair<int, int>> GenerateLambExtraAliasMapping();
template <typename T1, typename T2, typename T3, typename T4, typename T_GRAD, typename T_MIXED_PRECISION_FP>
class LambOptimizer : public CudaKernel {
Status ComputeInternal(OpKernelContext* ctx) const;
};
Import
#include "orttraining/training_ops/cuda/optimizer/lamb.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| eta | Tensor(T1) | Yes | Learning rate |
| loss_scale | Tensor(T3) | No | Loss scale factor for mixed precision |
| gradient_norm | Tensor(T_GRAD_NORM) | No | Global gradient norm for clipping |
| do_update | Tensor(bool) | No | Whether to perform the update (CPU input) |
| update_count | Tensor(T4) | Yes | Iteration count for bias correction |
| weights_i | Tensor(T2) | Yes | Parameter weights per group (up to 1024 groups) |
| gradients_i | Tensor(T_GRAD) | Yes | Gradients per group |
| moment_1_i | Tensor(T3) | Yes | First moment per group |
| moment_2_i | Tensor(T3) | Yes | Second moment per group |
| mixed_precision_weights_i | Tensor(T_MIXED_PRECISION_FP) | No | Mixed-precision weight copy per group |
Outputs
| Name | Type | Description |
|---|---|---|
| update_count_new | Tensor(T4) | Updated iteration count |
| weights_new_i | Tensor(T2) | Updated weights per group (in-place) |
| gradients_new_i | Tensor(T_GRAD) | Updated gradients per group (in-place) |
| moment_1_new_i | Tensor(T3) | Updated first moment per group (in-place) |
| moment_2_new_i | Tensor(T3) | Updated second moment per group (in-place) |
| mixed_precision_weights_new_i | Tensor(T_MIXED_PRECISION_FP) | Updated mixed-precision weights per group (in-place) |
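The in-place behavior noted in the outputs table depends on pairing each per-group input with its corresponding output. A minimal sketch of how such a pairing could be built follows, assuming the tensors are laid out exactly as in the tables above: five leading inputs, one leading output, then five tensors per group on each side. The function name SketchLambAliasMapping and the index offsets are illustrative assumptions derived from those tables; this is not the body of GenerateLambExtraAliasMapping.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Illustrative sketch only. Builds (input_index, output_index) pairs so that
// each per-group input tensor can share its buffer with the matching output.
std::vector<std::pair<int, int>> SketchLambAliasMapping(int group_count) {
  assert(group_count >= 0);
  // Assumed leading inputs: eta, loss_scale, gradient_norm, do_update, update_count.
  constexpr int kInputBias = 5;
  // Assumed leading output: update_count_new.
  constexpr int kOutputBias = 1;
  // Tensors per group: weights, gradients, moment_1, moment_2, mixed-precision weights.
  constexpr int kStride = 5;

  std::vector<std::pair<int, int>> pairs;
  pairs.reserve(static_cast<std::size_t>(group_count) * kStride);
  for (int i = 0; i < group_count; ++i) {
    for (int k = 0; k < kStride; ++k) {
      pairs.emplace_back(kInputBias + i * kStride + k,
                         kOutputBias + i * kStride + k);
    }
  }
  return pairs;
}
```

Under these assumptions, calling SketchLambAliasMapping(1024) would cover the maximum number of groups mentioned in the description, yielding five alias pairs per group.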
Usage Examples
// Registration for float weights, float gradients, MLFloat16 mixed precision
REGISTER_LAMB_KERNEL_TYPED(float, float, float, float, float, MLFloat16)
// Multi-tensor LAMB update is launched with:
// Stage 1: LambComputeDirection (compute d, update m1, m2)
// Stage 2: reduce_square_sum for w_norm and d_norm
// Stage 3: LambUpdate (apply trust ratio and update weights)