Implementation:NVIDIA TransformerEngine FusedSGD
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
GPU-fused SGD optimizer with optional momentum and Nesterov acceleration. It batches the updates for all parameters into a small number of kernel launches.
Description
FusedSGD subclasses torch.optim.Optimizer and performs batched parameter updates by calling multi_tensor_applier with the tex.multi_tensor_sgd kernel. It supports momentum, dampening, Nesterov momentum, weight decay (both standard L2 decay and a variant applied after the momentum update), and mixed-precision training with master weights. Parameters are grouped by dtype, and FP16/BF16 parameters are handled with FP32 master copies.
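The per-element update that a fused SGD kernel applies can be written as a plain-Python reference, which makes the interaction of the options above explicit. This is a minimal sketch that follows the update rule documented for torch.optim.SGD. The post-momentum weight-decay branch and its flag name `wd_after_momentum` are assumptions modeled on NVIDIA Apex's FusedSGD, not confirmed TransformerEngine API. Lists of floats stand in for tensors, and the sketch is not the library's kernel.

```python
from typing import List, Optional


def sgd_reference_step(
    param: List[float],
    grad: List[float],
    momentum_buf: Optional[List[float]],
    lr: float,
    momentum: float = 0.0,
    dampening: float = 0.0,
    weight_decay: float = 0.0,
    nesterov: bool = False,
    wd_after_momentum: bool = False,  # hypothetical flag name for the post-momentum variant
) -> Optional[List[float]]:
    """Apply one SGD step to `param` in place and return the momentum buffer.

    Pass momentum_buf=None on the first step; the buffer is then initialized
    from the (possibly weight-decayed) gradient, as torch.optim.SGD does.
    """
    first_run = momentum_buf is None
    if momentum != 0.0 and first_run:
        momentum_buf = [0.0] * len(param)
    for i in range(len(param)):
        d_p = grad[i]
        # Standard L2 weight decay: folded into the gradient before momentum.
        if weight_decay != 0.0 and not wd_after_momentum:
            d_p += weight_decay * param[i]
        if momentum != 0.0:
            if first_run:
                momentum_buf[i] = d_p
            else:
                momentum_buf[i] = momentum * momentum_buf[i] + (1.0 - dampening) * d_p
            # Nesterov looks ahead along the updated buffer; plain momentum uses the buffer itself.
            d_p = d_p + momentum * momentum_buf[i] if nesterov else momentum_buf[i]
        # Post-momentum variant (assumed semantics): decay is added after the momentum update.
        if weight_decay != 0.0 and wd_after_momentum:
            d_p += weight_decay * param[i]
        param[i] -= lr * d_p
    return momentum_buf
```

The two weight-decay placements agree on the first step but diverge afterwards. In the standard form the decay term is accumulated into the momentum buffer. In the post-momentum form it is applied once per step without being amplified by momentum.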
Usage
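Drop-in replacement for torch.optim.SGD that reduces kernel launch overhead by batching the updates for many parameters into each launch. Useful for fine-tuning and other workloads where SGD is preferred over adaptive optimizers. SGD keeps at most one momentum buffer per parameter, compared with the two moment estimates kept by Adam.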
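The launch-count saving from multi-tensor batching can be illustrated with a simplified planner. It first buckets parameters by dtype, then packs fixed-size chunks of each parameter into launches until a per-launch budget is exhausted. This is a hypothetical model of the idea behind multi_tensor_applier. The names `group_by_dtype` and `plan_launches` and the `chunk_size`, `max_tensors`, and `max_chunks` limits are illustrative assumptions, not the library's actual values or code.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# A work item is (parameter index, chunk index within that parameter).
WorkItem = Tuple[int, int]


def group_by_dtype(dtypes: List[str]) -> Dict[str, List[int]]:
    """Bucket parameter indices by dtype, since one fused launch handles one dtype combination."""
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, dtype in enumerate(dtypes):
        groups[dtype].append(idx)
    return dict(groups)


def plan_launches(
    indices: List[int],
    numels: List[int],
    chunk_size: int,
    max_tensors: int,
    max_chunks: int,
) -> List[List[WorkItem]]:
    """Pack the chunks of the given parameters into as few launches as the limits allow."""
    launches: List[List[WorkItem]] = []
    current: List[WorkItem] = []
    tensors_in_launch = set()
    for idx in indices:
        num_chunks = (numels[idx] + chunk_size - 1) // chunk_size
        for chunk in range(num_chunks):
            # Flush when the chunk budget is full, or when a new tensor would exceed the tensor budget.
            new_tensor = idx not in tensors_in_launch
            if len(current) == max_chunks or (new_tensor and len(tensors_in_launch) == max_tensors):
                launches.append(current)
                current, tensors_in_launch = [], set()
            current.append((idx, chunk))
            tensors_in_launch.add(idx)
    if current:
        launches.append(current)
    return launches
```

Consider ten FP32 and two FP16 parameters of 100 elements each, with `chunk_size=64`, `max_tensors=4`, and `max_chunks=320`. The planner yields three FP32 launches and one FP16 launch. A per-parameter loop would need twelve launches.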
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/optimizers/fused_sgd.py
- Lines: 1-316
Signature
```python
class FusedSGD(torch.optim.Optimizer):
    def __init__(
        self, params, lr=0.1, momentum=0,
        dampening=0, weight_decay=0,
        nesterov=False, set_grad_none=True,
        master_weights=False, ...
    ): ...

    def step(self, closure=None): ...
```
Import
```python
from transformer_engine.pytorch.optimizers import FusedSGD
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| params | iterable | Yes | Iterable of parameters or parameter groups |
| lr | float | No | Learning rate (default 0.1) |
| momentum | float | No | Momentum factor (default 0) |
| dampening | float | No | Dampening for momentum (default 0) |
| weight_decay | float | No | Weight decay (L2 penalty, default 0) |
| nesterov | bool | No | Enable Nesterov momentum (default False) |
| master_weights | bool | No | Maintain FP32 master weights (default False) |
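The master_weights option exists because small updates can be lost entirely when they are applied directly to a low-precision weight. The following is a minimal, library-independent sketch that uses Python's `struct` half- and single-precision formats to emulate FP16 and FP32 rounding. It compares updating an FP16 weight in place against updating an FP32 master copy and writing a rounded FP16 copy back. The helper names are hypothetical. The step size 1e-4 (for example, lr=0.01 times a gradient of 0.01) is an arbitrary illustration.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE binary16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def train(steps: int, update: float, use_master: bool) -> float:
    """Apply `steps` identical updates of size `update` to a weight initialized at 1.0."""
    weight_fp16 = to_fp16(1.0)
    master_fp32 = to_fp32(weight_fp16)  # FP32 copy of the FP16 weight
    for _ in range(steps):
        if use_master:
            # Accumulate in FP32, then write a rounded FP16 copy for the forward pass.
            master_fp32 = to_fp32(master_fp32 - update)
            weight_fp16 = to_fp16(master_fp32)
        else:
            # Update the FP16 weight directly; each small step rounds back to the old value.
            weight_fp16 = to_fp16(weight_fp16 - update)
    return weight_fp16
```

Near 1.0 the FP16 spacing is about 4.9e-4, so without master weights every 1e-4 step rounds back to 1.0 and the weight never moves. With an FP32 master copy the hundred steps accumulate to roughly 0.99.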
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Optional[torch.Tensor] | Loss value returned by the closure if one is provided, else None |
Usage Examples
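```python
import torch
from transformer_engine.pytorch.optimizers import FusedSGD

# The fused multi-tensor kernels run on the GPU, so parameters and inputs live on CUDA.
model = torch.nn.Linear(1024, 1024).cuda()
inputs = torch.randn(32, 1024, device="cuda")

# Nesterov momentum requires a nonzero momentum and zero dampening (the default).
optimizer = FusedSGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-4,
    nesterov=True,
)

# One training step: clear gradients, compute a loss, backpropagate, update.
optimizer.zero_grad()
loss = model(inputs).sum()
loss.backward()
optimizer.step()
```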