Implementation:NVIDIA TransformerEngine FusedAdam
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Optimization, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
GPU-fused Adam/AdamW optimizer that batches elementwise updates across all parameters into a small number of multi-tensor kernel launches. It supports FP8 primary weights, mixed-precision master weights, and CUDA Graph capture.
Description
FusedAdam extends torch.optim.Optimizer and uses multi_tensor_applier to launch batched CUDA kernels from the C++ extension (multi_tensor_adam, multi_tensor_adam_fp8, multi_tensor_adam_capturable, and multi_tensor_adam_capturable_master). Supported features:
- reduced-precision optimizer states (FP16, BF16, or FP8), selected per moment with exp_avg_dtype and exp_avg_sq_dtype
- FP32 or FP16 master weights
- BF16 parameter remainder storage, which saves memory by keeping only the low-order bits needed to reconstruct FP32 master weights from BF16 parameters
- decoupled gradients
- DTensor unwrapping for distributed tensors
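The core idea behind multi_tensor_applier is that one kernel launch processes many parameter tensors at once, instead of one launch per tensor. The pure-Python sketch below illustrates only that batching idea: `multi_tensor_apply_sketch`, `scale_inplace`, and the limit of 4 tensors per launch are hypothetical, and each call to `op` merely stands in for a CUDA kernel launch. The real applier also splits large tensors into fixed-size chunks and uses different limits.

```python
from typing import Callable, List

Tensor = List[float]  # stand-in for a GPU tensor in this sketch


def multi_tensor_apply_sketch(
    op: Callable[[List[Tensor]], None],
    tensors: List[Tensor],
    max_tensors_per_launch: int = 4,  # hypothetical batch limit
) -> int:
    """Apply `op` to `tensors` in batches and return the number of simulated launches."""
    launches = 0
    for start in range(0, len(tensors), max_tensors_per_launch):
        # One call here models one fused kernel launch over a whole batch.
        op(tensors[start : start + max_tensors_per_launch])
        launches += 1
    return launches


def scale_inplace(batch: List[Tensor], factor: float = 0.5) -> None:
    # Elementwise update applied to every tensor in the batch.
    for t in batch:
        for i in range(len(t)):
            t[i] *= factor


params = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0], [7.0], [8.0], [9.0, 10.0]]
n = multi_tensor_apply_sketch(scale_inplace, params)
print(n)  # 2 launches instead of 6 (one per tensor)
```

With six tensors, the per-tensor approach would need six launches; batching needs two. Launch overhead is roughly fixed per launch, so the savings grow with the number of parameter tensors in the model.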
Usage
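Intended as a drop-in replacement for torch.optim.Adam/AdamW. Multi-tensor batching reduces kernel launch overhead, and reduced-precision (including FP8) optimizer states reduce optimizer memory. Because adam_w_mode defaults to True, the default behavior is AdamW (decoupled weight decay). Pass adam_w_mode=False to get Adam with weight decay applied as an L2 term on the gradient.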
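To make the memory savings concrete, the sketch below estimates the bytes held by optimizer state per parameter, excluding the parameters and gradients themselves. It assumes 1 byte per element for FP8, ignores per-tensor FP8 scaling factors and allocator overhead, and assumes that remainder storage keeps 16 bits per element in place of a full FP32 master copy. `optimizer_state_bytes` is a hypothetical helper, not part of TransformerEngine.

```python
from typing import Optional

# Bytes per element. "remainder16" models the 16-bit remainder that
# replaces an FP32 master copy when remainder storage is enabled (assumption).
BYTES_PER_ELEMENT = {"fp32": 4, "fp16": 2, "bf16": 2, "fp8": 1, "remainder16": 2}


def optimizer_state_bytes(
    num_params: int,
    exp_avg: str = "fp32",
    exp_avg_sq: str = "fp32",
    master: Optional[str] = None,  # None means no master weights are kept
) -> int:
    """Estimate optimizer-state bytes (moments plus optional master copy)."""
    per_param = BYTES_PER_ELEMENT[exp_avg] + BYTES_PER_ELEMENT[exp_avg_sq]
    if master is not None:
        per_param += BYTES_PER_ELEMENT[master]
    return num_params * per_param


n = 1_000_000_000  # 1B parameters
baseline = optimizer_state_bytes(n, master="fp32")  # 12 bytes/param
fp8_exp_avg = optimizer_state_bytes(n, exp_avg="fp8", master="fp32")  # 9 bytes/param
fp8_with_remainder = optimizer_state_bytes(n, exp_avg="fp8", master="remainder16")  # 7 bytes/param
for label, b in [("baseline", baseline), ("fp8 exp_avg", fp8_exp_avg), ("+ remainder", fp8_with_remainder)]:
    print(f"{label}: {b / 2**30:.1f} GiB")
```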
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/optimizers/fused_adam.py
- Lines: 1-760
Signature
```python
class FusedAdam(torch.optim.Optimizer):
    def __init__(
        self, params, lr=1e-3, bias_correction=True,
        betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True,
        weight_decay=0.0, set_grad_none=True,
        capturable=False, master_weights=False,
        exp_avg_dtype=None, exp_avg_sq_dtype=None, ...
    ): ...

    def step(self, closure=None, grad_scaler=None): ...
```
Import
```python
from transformer_engine.pytorch.optimizers import FusedAdam
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| params | iterable | Yes | Iterable of parameters or parameter groups |
| lr | float | No | Learning rate (default 1e-3) |
| betas | tuple | No | Adam beta1 and beta2 (default (0.9, 0.999)) |
| eps | float | No | Term added to the denominator for numerical stability (default 1e-8) |
| adam_w_mode | bool | No | Use decoupled weight decay (AdamW) if True, L2-style decay (Adam) if False (default True) |
| master_weights | bool | No | Maintain FP32 master weights (default False) |
| exp_avg_dtype | torch.dtype | No | Dtype for the first moment, e.g. torch.float8_e4m3fn (default None) |
| capturable | bool | No | Enable CUDA Graph capturable mode (default False) |
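The parameters above control the standard Adam/AdamW update. The unfused, per-element sketch below shows how lr, betas, eps, weight_decay, adam_w_mode, and bias_correction interact. It follows the textbook formulas and is not TransformerEngine code; the fused kernels may order operations differently and compute in other precisions. `adam_reference_step` is a hypothetical name.

```python
import math
from typing import List, Tuple


def adam_reference_step(
    param: List[float],
    grad: List[float],
    exp_avg: List[float],
    exp_avg_sq: List[float],
    step: int,  # 1-based step count, already incremented for this update
    lr: float = 1e-3,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    adam_w_mode: bool = True,
    bias_correction: bool = True,
) -> None:
    """Apply one Adam/AdamW update to plain Python lists, in place."""
    beta1, beta2 = betas
    bc1 = 1.0 - beta1 ** step if bias_correction else 1.0
    bc2 = 1.0 - beta2 ** step if bias_correction else 1.0
    for i in range(len(param)):
        g = grad[i]
        if not adam_w_mode and weight_decay != 0.0:
            g += weight_decay * param[i]  # Adam: decay folded into the gradient
        exp_avg[i] = beta1 * exp_avg[i] + (1.0 - beta1) * g
        exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1.0 - beta2) * g * g
        denom = math.sqrt(exp_avg_sq[i] / bc2) + eps
        update = (exp_avg[i] / bc1) / denom
        if adam_w_mode and weight_decay != 0.0:
            update += weight_decay * param[i]  # AdamW: decoupled decay
        param[i] -= lr * update


p, g = [1.0], [0.5]
m, v = [0.0], [0.0]
adam_reference_step(p, g, m, v, step=1, lr=0.1)
print(p)  # approximately [0.9]: with bias correction, step 1 moves by about lr * sign(grad)
```

The difference between the two modes is where weight_decay enters: Adam adds it to the gradient before the moment updates, so it gets rescaled by the adaptive denominator, while AdamW adds it directly to the final update.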
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Optional[torch.Tensor] | Loss value if a closure is provided, else None |
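The return value follows the torch.optim.Optimizer convention: if a closure is passed, step() calls it to re-evaluate the loss (and refresh gradients) and returns that loss; otherwise it returns None. The sketch below reproduces only that contract with a hypothetical `ToySGD` class on plain floats. It does not model FusedAdam's math, and real PyTorch optimizers additionally run the closure under `torch.enable_grad()`.

```python
from typing import Callable, List, Optional


class ToySGD:
    """Hypothetical optimizer illustrating the closure/return contract of step()."""

    def __init__(self, params: List[float], lr: float = 0.1) -> None:
        self.params = params
        self.lr = lr
        self.grads = [0.0] * len(params)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = None
        if closure is not None:
            # The closure recomputes the loss and refreshes the gradients.
            loss = closure()
        for i, g in enumerate(self.grads):
            self.params[i] -= self.lr * g
        return loss


opt = ToySGD([0.0])


def closure() -> float:
    # Loss (x - 3)^2 with gradient 2 * (x - 3).
    x = opt.params[0]
    opt.grads[0] = 2.0 * (x - 3.0)
    return (x - 3.0) ** 2


print(opt.step(closure))  # 9.0: loss evaluated before the update
print(opt.step())         # None: no closure, so no loss is returned
```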
Usage Examples
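```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.optimizers import FusedAdam

# Any CUDA model works; a BF16 te.Linear is used here for illustration,
# so that FP32 master weights are meaningful.
model = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)

# Create FusedAdam in AdamW mode with FP32 master weights and an FP8 first-moment state
optimizer = FusedAdam(
    model.parameters(),
    lr=1e-4,
    adam_w_mode=True,
    weight_decay=0.01,
    master_weights=True,
    exp_avg_dtype=torch.float8_e4m3fn,
)

# One step of a standard training loop
optimizer.zero_grad()
loss = model(inp).sum()
loss.backward()
optimizer.step()
```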
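Because params also accepts parameter groups, a common pattern is to exempt biases and normalization weights from weight decay. The helper below is a hypothetical sketch of that pattern: the name `split_decay_groups` and the keyword heuristic ("bias", "norm") are assumptions, not FusedAdam requirements, and should be adapted to the model's actual parameter names.

```python
from typing import Any, Dict, Iterable, List, Tuple


def split_decay_groups(
    named_params: Iterable[Tuple[str, Any]],
    weight_decay: float = 0.01,
    no_decay_keywords: Tuple[str, ...] = ("bias", "norm"),  # heuristic, adjust per model
) -> List[Dict[str, Any]]:
    """Build two parameter groups: one with weight decay and one without."""
    decay: List[Any] = []
    no_decay: List[Any] = []
    for name, p in named_params:
        if not getattr(p, "requires_grad", True):
            continue  # frozen parameters are left out of the optimizer
        if any(k in name.lower() for k in no_decay_keywords):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Strings stand in for tensors here; with a real model you would pass
# model.named_parameters(), e.g.:
#   optimizer = FusedAdam(split_decay_groups(model.named_parameters()), lr=1e-4)
named = [("fc1.weight", "W1"), ("fc1.bias", "b1"), ("layer_norm_weight", "g")]
groups = split_decay_groups(named)
print([grp["params"] for grp in groups])  # [['W1'], ['b1', 'g']]
```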