Implementation: deepspeedai/DeepSpeed Fused LAMB Frontend
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep Learning, GPU Computing |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
C++ frontend providing PyTorch bindings for the fused LAMB (Layer-wise Adaptive Moments optimizer for Batch training) CUDA implementation.
Description
This file serves as the C++ interface layer between PyTorch and the CUDA-accelerated LAMB optimizer implementation. LAMB is designed for large-batch training: it enables BERT-style models to train at very large batch sizes by applying a layer-wise adaptive trust ratio to each parameter update. The frontend performs input validation, creates intermediate tensors for the L2 norm reductions, and dispatches to the CUDA kernel implementation. It handles both FP16 and FP32 precision modes and returns the computed trust ratio coefficient for monitoring training stability.
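The layer-wise adaptation can be sketched in pure Python. This is a simplified reference of the math from the LAMB paper (assuming LAMB mode with bias correction and decoupled weight decay), not the fused kernel itself, which performs these loops and both norm reductions in a single GPU pass:

```python
import math

def lamb_update(p, m, v, g, lr, beta1, beta2, eps, decay, step,
                max_coeff=10.0, min_coeff=0.01):
    """Reference LAMB update for one layer (plain lists of floats)."""
    # Update biased first and second moment estimates.
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    # Bias correction (applies when bias_correction == 1).
    bc1 = 1 - beta1 ** step
    bc2 = 1 - beta2 ** step
    # Adam-style update direction plus weight decay.
    u = [(mi / bc1) / (math.sqrt(vi / bc2) + eps) + decay * pi
         for mi, vi, pi in zip(m, v, p)]
    # Layer-wise trust ratio: ||p|| / ||u||, clamped to [min, max].
    p_norm = math.sqrt(sum(x * x for x in p))
    u_norm = math.sqrt(sum(x * x for x in u))
    ratio = p_norm / u_norm if p_norm > 0 and u_norm > 0 else 1.0
    ratio = max(min_coeff, min(max_coeff, ratio))
    p = [pi - lr * ratio * ui for pi, ui in zip(p, u)]
    return p, m, v, ratio
```

The trust ratio is what the fused kernel returns to the caller; clamping it between min_coeff and max_coeff bounds how far any single layer's effective learning rate can drift from the base rate.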
Usage
Use this optimizer when training large models with very large batch sizes on GPU, particularly for BERT-style architectures where layer-wise adaptation improves convergence.
Code Reference
Source Location
- Repository: deepspeedai/DeepSpeed
- File: csrc/lamb/fused_lamb_cuda.cpp
Signature
at::Tensor lamb(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay);
Import
#include <torch/extension.h>
I/O Contract
lamb Function Parameters
| Parameter | Type | Description |
|---|---|---|
| p | at::Tensor& | Model parameters (in/out) |
| p_copy | at::Tensor& | Optional parameter copy for mixed precision (in/out) |
| m | at::Tensor& | First moment estimates (in/out) |
| v | at::Tensor& | Second moment estimates (in/out) |
| g | at::Tensor& | Gradients (in) |
| lr | float | Learning rate |
| beta1 | float | Exponential decay rate for first moment |
| beta2 | float | Exponential decay rate for second moment |
| max_coeff | float | Maximum trust ratio coefficient |
| min_coeff | float | Minimum trust ratio coefficient |
| eps | float | Small constant for numerical stability |
| grad_scale | float | Gradient scaling factor for mixed precision |
| step | int | Current training step number |
| mode | int | Optimizer mode (0: Adam-style, 1: LAMB-style) |
| bias_correction | int | Enable bias correction (0 or 1) |
| decay | float | Weight decay coefficient |
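A minimal arithmetic sketch (not DeepSpeed code) of what the bias_correction flag compensates for: because m and v start at zero, the early moment estimates are biased low, and dividing by 1 - beta^step rescales them toward their true magnitude:

```python
def bias_correction_factor(beta, step):
    """Denominator applied to a moment estimate when bias_correction=1."""
    return 1.0 - beta ** step

beta1, beta2 = 0.9, 0.999
for step in (1, 10, 1000):
    bc1 = bias_correction_factor(beta1, step)
    bc2 = bias_correction_factor(beta2, step)
    # Early steps divide by small factors (large correction);
    # the factors approach 1.0 as training progresses.
    print(f"step {step:4d}: m / {bc1:.4f}, v / {bc2:.6f}")
```

At step 1 the first moment is divided by 0.1 and the second by 0.001; by step 1000 both factors are effectively 1 and the correction vanishes.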
Returns
| Return Type | Description |
|---|---|
| at::Tensor | Scalar tensor containing the computed LAMB coefficient (trust ratio) |
Tensor Requirements
| Requirement | Description |
|---|---|
| Device | All tensors must be on CUDA device |
| Layout | All tensors must be contiguous |
| Shape | p, m, v, g must have same shape; p_copy can be empty or same shape |
| Dtype | FP16 or FP32 (mixed precision supported with p_copy) |
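The requirements above can be expressed as a pre-call check. The sketch below is a hypothetical Python helper (the real validation happens in C++ on at::Tensor objects); TensorMeta is a stand-in for the tensor properties being checked, not a DeepSpeed type:

```python
from dataclasses import dataclass

@dataclass
class TensorMeta:
    """Minimal stand-in for the tensor properties the frontend validates."""
    is_cuda: bool
    contiguous: bool
    shape: tuple
    numel: int

def check_lamb_inputs(p, p_copy, m, v, g):
    """Raise ValueError if the tensors violate the I/O contract."""
    for name, t in (("p", p), ("m", m), ("v", v), ("g", g)):
        if not t.is_cuda:
            raise ValueError(f"{name} must be a CUDA tensor")
        if not t.contiguous:
            raise ValueError(f"{name} must be contiguous")
        if t.shape != p.shape:
            raise ValueError(f"{name} must match p's shape")
    # p_copy may be empty (FP32 path) or match p (mixed precision).
    if p_copy.numel != 0 and p_copy.shape != p.shape:
        raise ValueError("p_copy must be empty or match p's shape")
```

Passing an empty p_copy selects the single-precision path; a populated p_copy of the same shape signals the mixed-precision path.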
Usage Examples
import torch
from deepspeed.ops.op_builder import FusedLambBuilder
# JIT-load the CUDA extension that exposes the lamb() binding
fused_lamb_cuda = FusedLambBuilder().load()
# Prepare tensors on GPU
params = torch.randn(1024, device='cuda', dtype=torch.float32)
params_copy = torch.empty(0)  # empty tensor selects the FP32 path
momentum = torch.zeros_like(params)
variance = torch.zeros_like(params)
grads = torch.randn_like(params)
# Execute one LAMB optimizer step (arguments are passed positionally,
# in the order of the C++ signature)
lamb_coeff = fused_lamb_cuda.lamb(
    params,       # p
    params_copy,  # p_copy
    momentum,     # m
    variance,     # v
    grads,        # g
    0.001,        # lr
    0.9,          # beta1
    0.999,        # beta2
    10.0,         # max_coeff
    0.01,         # min_coeff
    1e-6,         # eps
    1.0,          # grad_scale
    1,            # step
    1,            # mode (1: LAMB-style)
    1,            # bias_correction
    0.01          # decay
)
print(f"LAMB coefficient (trust ratio): {lamb_coeff.item()}")
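Since the returned trust ratio is intended for monitoring training stability, it can be tracked across steps. A hypothetical helper (illustrative only, not a DeepSpeed API) that flags how often the clamp bounds are binding:

```python
def clamp_fraction(ratios, min_coeff=0.01, max_coeff=10.0, tol=1e-6):
    """Fraction of steps where the trust ratio hit a clamp bound.

    A value near 1.0 suggests max_coeff/min_coeff are binding on most
    steps and may be worth retuning.
    """
    hits = sum(1 for r in ratios
               if abs(r - min_coeff) < tol or abs(r - max_coeff) < tol)
    return hits / len(ratios)
```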