Implementation:NVIDIA TransformerEngine FusedAdam

From Leeroopedia


| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Optimization, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |

Overview

GPU-fused Adam/AdamW optimizer that batches elementwise updates across all parameters into minimal kernel launches, with support for FP8 primary weights, mixed-precision master weights, and CUDA Graph capture.
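
The elementwise update that a fused kernel applies to every parameter element can be sketched in plain Python. This is a minimal illustration of the textbook Adam and AdamW rules on a single scalar, not the TransformerEngine kernel itself; the function name adam_step is hypothetical, and its argument names simply mirror the constructor arguments shown in the Signature section.

```python
import math


def adam_step(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
              betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
              adam_w_mode=True, bias_correction=True):
    """Apply one Adam/AdamW update to a scalar; return (param, exp_avg, exp_avg_sq)."""
    beta1, beta2 = betas
    if not adam_w_mode:
        # Classic Adam with L2 regularization: fold the decay into the gradient.
        grad = grad + weight_decay * param
    # Exponential moving averages of the gradient and the squared gradient.
    exp_avg = beta1 * exp_avg + (1.0 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1.0 - beta2) * grad * grad
    if bias_correction:
        c1 = 1.0 - beta1 ** step
        c2 = 1.0 - beta2 ** step
    else:
        c1 = c2 = 1.0
    update = (exp_avg / c1) / (math.sqrt(exp_avg_sq / c2) + eps)
    if adam_w_mode:
        # AdamW: the decay is decoupled and applied directly to the parameter.
        update = update + weight_decay * param
    return param - lr * update, exp_avg, exp_avg_sq


# First step on a single value, with and without decoupled weight decay.
p_adamw, m, v = adam_step(1.0, 1.0, 0.0, 0.0, step=1, weight_decay=0.1)
p_adam, _, _ = adam_step(1.0, 1.0, 0.0, 0.0, step=1, weight_decay=0.1,
                         adam_w_mode=False)
print(p_adamw, p_adam)
```

With adam_w_mode=True the decay term bypasses the moment estimates, so the two modes produce different parameters from the same gradient even on the first step.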

Description

FusedAdam extends torch.optim.Optimizer and uses multi_tensor_applier to launch batched CUDA kernels (multi_tensor_adam, multi_tensor_adam_fp8, multi_tensor_adam_capturable, multi_tensor_adam_capturable_master) from the C++ extension. It supports FP16, BF16, and FP8 optimizer states (set through exp_avg_dtype and exp_avg_sq_dtype), FP32 or FP16 master weights, storage of BF16 parameter remainders to save memory, decoupled gradients, and unwrapping of DTensor parameters for distributed training.
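
The idea behind BF16 parameter remainder storage can be sketched with bit manipulation: a BF16 value has the same layout as the upper 16 bits of an FP32 value, so keeping the lower 16 bits as a "remainder" is enough to rebuild the full FP32 number. The sketch below uses plain truncation and hypothetical helper names (split_fp32, join_fp32); the upstream kernels may use a different rounding scheme or layout.

```python
import struct


def f32_bits(x):
    """Return the IEEE-754 binary32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(bits):
    """Interpret a 32-bit unsigned int as an IEEE-754 binary32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def split_fp32(x):
    """Split an FP32 value into (BF16 bits, 16-bit remainder) by truncation."""
    bits = f32_bits(x)
    return bits >> 16, bits & 0xFFFF


def join_fp32(bf16_bits, remainder):
    """Rebuild the FP32 value from its BF16 half and the stored remainder."""
    return bits_f32((bf16_bits << 16) | remainder)


original = bits_f32(f32_bits(0.1))      # 0.1 rounded to FP32
hi, lo = split_fp32(original)
bf16_only = bits_f32(hi << 16)          # what the BF16 parameter alone holds
restored = join_fp32(hi, lo)            # BF16 parameter plus remainder
print(original, bf16_only, restored)
```

Under this scheme the optimizer would hold 16 extra bits per parameter instead of a separate 32-bit master copy, which is where the memory saving comes from.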

Usage

FusedAdam is intended as a drop-in replacement for torch.optim.Adam and torch.optim.AdamW, with adam_w_mode selecting between the two update rules. It reduces kernel launch overhead through multi-tensor batching and can store optimizer states in FP8 to save memory.
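
To make the launch-overhead argument concrete, the sketch below shows only the bookkeeping side of multi-tensor batching: parameters are grouped into lists that one batched kernel could process together. The grouping key (dtype), the helper name bucket_by_dtype, and the "one launch per bucket" count are simplifying assumptions for illustration; a real multi-tensor launch may split a long list into several chunks.

```python
from collections import defaultdict


def bucket_by_dtype(params):
    """Group (name, dtype) pairs into per-dtype lists, preserving order."""
    buckets = defaultdict(list)
    for name, dtype in params:
        buckets[dtype].append(name)
    return dict(buckets)


# Hypothetical parameter list: names and dtypes are made up for the example.
params = [
    ("w1", "bf16"), ("b1", "bf16"),
    ("w2", "fp32"), ("b2", "fp32"),
    ("w3", "bf16"),
]
buckets = bucket_by_dtype(params)

per_param_launches = len(params)   # one kernel launch per tensor
batched_launches = len(buckets)    # one batched launch per dtype bucket
print(per_param_launches, "launches unbatched vs", batched_launches, "batched")
```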

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/optimizers/fused_adam.py
Lines: 1-760

Signature

class FusedAdam(torch.optim.Optimizer):
    def __init__(
        self, params, lr=1e-3, bias_correction=True,
        betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True,
        weight_decay=0.0, set_grad_none=True,
        capturable=False, master_weights=False,
        exp_avg_dtype=None, exp_avg_sq_dtype=None, ...
    ): ...

    def step(self, closure=None, grad_scaler=None): ...

Import

from transformer_engine.pytorch.optimizers import FusedAdam

I/O Contract

Inputs

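| Name | Type | Required | Description |
|---|---|---|---|
| params | iterable | Yes | Iterable of parameters or parameter groups |
| lr | float | No | Learning rate (default 1e-3) |
| betas | tuple | No | Adam beta1 and beta2 (default (0.9, 0.999)) |
| eps | float | No | Epsilon for numerical stability (default 1e-8) |
| adam_w_mode | bool | No | Use decoupled weight decay (AdamW) instead of L2 regularization (default True) |
| master_weights | bool | No | Maintain higher-precision master weights, FP32 or FP16 per the Description (default False) |
| exp_avg_dtype | torch.dtype | No | Dtype for the first moment, exp_avg (FP16, BF16, or FP8 per the Description). The upstream code is understood to select FP8 with torch.uint8 rather than torch.float8_e4m3fn; verify against your installed version |
| capturable | bool | No | CUDA Graph capturable mode (default False) |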
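
Why master_weights matters for low-precision parameters can be shown with a small simulation. The sketch below emulates BF16 and FP32 rounding in plain Python (the helpers to_bf16 and to_f32 are hypothetical, and the fixed per-step delta stands in for lr times the Adam update); it illustrates the numerical effect only and does not reproduce the optimizer's internal storage.

```python
import struct


def to_f32(x):
    """Round a Python float to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    """Round a finite float to the nearest BF16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # rounding bias for the dropped half
    bits &= 0xFFFF0000                    # keep only the upper 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


delta = 1e-4   # size of each update, far below BF16 resolution near 1.0
steps = 1000

# Case 1: the BF16 parameter is updated in place, with no master copy.
param_only = to_bf16(1.0)
for _ in range(steps):
    param_only = to_bf16(param_only - delta)

# Case 2: an FP32 master copy accumulates updates; the BF16 parameter is
# refreshed from it after every step.
master = to_f32(1.0)
param = to_bf16(master)
for _ in range(steps):
    master = to_f32(master - delta)
    param = to_bf16(master)

print("without master weights:", param_only)
print("with master weights:   ", param, "(master =", master, ")")
```

Without a master copy the parameter never leaves 1.0, because each 1e-4 update is smaller than half the BF16 spacing just below 1.0 (2^-9, about 0.00195) and is rounded away. With the FP32 master copy the updates accumulate to roughly 0.9 and the BF16 parameter follows.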

Outputs

| Name | Type | Description |
|---|---|---|
| loss | Optional[torch.Tensor] | Loss value if closure provided, else None |

Usage Examples

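import torch
from transformer_engine.pytorch.optimizers import FusedAdam

# Toy BF16 model and batch on the GPU (illustrative; FusedAdam runs CUDA kernels).
model = torch.nn.Linear(16, 16, device="cuda", dtype=torch.bfloat16)
inputs = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)

# Create FusedAdam optimizer with FP8 states
optimizer = FusedAdam(
    model.parameters(),
    lr=1e-4,
    adam_w_mode=True,
    weight_decay=0.01,
    master_weights=True,
    # Assumption: upstream selects FP8 state storage with torch.uint8;
    # check the FusedAdam docstring of your installed version.
    exp_avg_dtype=torch.uint8,
)

# Standard training loop
optimizer.zero_grad()
loss = model(inputs).sum()
loss.backward()
optimizer.step()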
