Implementation:NVIDIA TransformerEngine PyTorch Triton Cross Entropy
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
PyTorch wrapper functions for the Triton-based fused cross entropy kernels, implementing online softmax and cross entropy loss computation with optional distributed support.
Description
This module provides the PyTorch interface to the Triton cross entropy kernels defined in transformer_engine.common.triton.cross_entropy. The forward pass uses a two-phase approach:
- Online softmax phase (online_softmax_kernel): computes the running max, the softmax denominator, and the target logit in a single pass over the vocabulary, for numerical stability.
- Cross entropy phase (cross_entropy_kernel): computes the loss from these stabilized values, with optional label smoothing. In the distributed case, it uses softmax statistics all-gathered across ranks.
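As an illustration of the two phases described above, the following pure-Python sketch processes a single row of logits. It is not the Triton kernel code. The function names online_softmax_stats and cross_entropy_from_stats are hypothetical, and the label-smoothing formula is an assumption that follows PyTorch's CrossEntropyLoss definition (a mix of the one-hot target and a uniform distribution over the vocabulary).

```python
import math
from typing import List, Tuple


def online_softmax_stats(logits: List[float], target: int) -> Tuple[float, float, float]:
    """Phase 1 (sketch): one pass computing running max m, denominator d, target logit x_y."""
    m = -math.inf
    d = 0.0
    x_y = math.nan
    for i, x in enumerate(logits):
        new_m = max(m, x)
        # Rescale the running denominator to the new max, then add this element's term
        d = d * math.exp(m - new_m) + math.exp(x - new_m)
        m = new_m
        if i == target:
            x_y = x
    return m, d, x_y


def cross_entropy_from_stats(
    logits: List[float], m: float, d: float, x_y: float, label_smoothing: float = 0.0
) -> float:
    """Phase 2 (sketch): loss from the stabilized statistics.

    Uses -log p_i = m + log(d) - x_i. Label smoothing is ASSUMED to follow PyTorch's
    definition: (1 - eps) * NLL(target) + eps * mean_i(-log p_i).
    """
    lse = m + math.log(d)  # log-sum-exp of the row
    nll = lse - x_y
    if label_smoothing == 0.0:
        return nll
    smooth = sum(lse - x for x in logits) / len(logits)
    return (1.0 - label_smoothing) * nll + label_smoothing * smooth


row = [1.0, 2.0, 0.5, -1.0]
m, d, x_y = online_softmax_stats(row, target=1)
print(cross_entropy_from_stats(row, m, d, x_y, label_smoothing=0.1))
```

Rescaling d whenever the max grows is what lets the statistics be gathered in one pass without ever exponentiating a large positive number.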
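The backward pass uses element_mul_kernel to scale the gradient efficiently: because the forward pass leaves the gradient of the loss with respect to the logits in _input, backward only needs to multiply it by grad_output. The module also supports distributed loss computation in which the vocabulary dimension is split across ranks; all_gather_into_tensor synchronizes the softmax statistics between them.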
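The distributed statistics exchange and the cheap backward pass can be sketched in plain Python. This sketch assumes that the gathered per-rank (max, denominator) pairs are combined with the standard online-softmax merge rule, and that the gradient stored in _input is softmax(x) minus the one-hot target (no label smoothing or reduction scaling). All function names are hypothetical and lists stand in for tensor shards; it illustrates the math, not TransformerEngine's code.

```python
import math
from typing import List, Tuple


def shard_stats(shard: List[float]) -> Tuple[float, float]:
    """Local (max, denominator) for one rank's slice of the vocabulary."""
    m = max(shard)
    return m, sum(math.exp(x - m) for x in shard)


def merge_shard_stats(stats: List[Tuple[float, float]]) -> Tuple[float, float]:
    """ASSUMED merge rule for gathered (m_r, d_r): m = max m_r, d = sum d_r * exp(m_r - m)."""
    m = max(m_r for m_r, _ in stats)
    d = sum(d_r * math.exp(m_r - m) for m_r, d_r in stats)
    return m, d


def ce_grad(logits: List[float], target: int, m: float, d: float) -> List[float]:
    """Gradient of the unreduced, unsmoothed loss w.r.t. logits: softmax(x) - one_hot(target)."""
    grad = [math.exp(x - m) / d for x in logits]
    grad[target] -= 1.0
    return grad


def scale_in_place(grad: List[float], grad_output: float) -> None:
    """Backward (sketch): the stored gradient only needs an element-wise scale by grad_output."""
    for i in range(len(grad)):
        grad[i] *= grad_output


full = [0.3, -1.2, 2.5, 0.0, 1.1, -0.4]
shards = [full[:3], full[3:]]  # vocabulary split across two hypothetical ranks
m, d = merge_shard_stats([shard_stats(s) for s in shards])
# Computed on the full row for brevity; each rank would only handle its own slice
grad = ce_grad(full, target=2, m=m, d=d)
scale_in_place(grad, grad_output=0.5)
```

Only two floats per row need to cross ranks under this scheme, which is why gathering statistics is much cheaper than gathering the full logits.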
Usage
These functions are called by the CrossEntropyFunction autograd function in transformer_engine.pytorch.cross_entropy and are not typically called directly.
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/triton/cross_entropy.py
- Lines: 1-140
Signature
```python
def cross_entropy_forward(_input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx) -> Tuple[torch.Tensor, torch.Tensor]: ...
def cross_entropy_backward(_input, grad_output, is_cg_capturable=False) -> torch.Tensor: ...
```
Import
```python
from transformer_engine.pytorch.triton.cross_entropy import cross_entropy_forward, cross_entropy_backward
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| _input | torch.Tensor | Yes | Logits of shape (B, SQ, V) |
| target | torch.Tensor | Yes | Target indices of shape (B, SQ) |
| label_smoothing | float | Yes | Smoothing factor (0.0 = no smoothing) |
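| reduce_loss | bool | Yes | If True, return the loss averaged to a scalar; otherwise return per-token losses of shape (B, SQ) |
| dist_process_group | ProcessGroup or None | Yes (may be None) | Process group for vocab-parallel loss; pass None for single-device computation |
| ignore_idx | int | Yes | Target index whose tokens are excluded from the loss computation |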
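To make the reduce_loss and ignore_idx contract concrete, the following hypothetical reference function computes per-token losses over flattened (B * SQ) rows. It assumes, as with PyTorch's cross_entropy, that ignored tokens contribute zero loss and that the reduced loss is averaged over non-ignored tokens only; label smoothing is omitted for brevity. The real function returns an unreduced loss of shape (B, SQ) rather than a flat list.

```python
import math
from typing import List, Union


def reference_loss(
    logits: List[List[float]],
    targets: List[int],
    reduce_loss: bool,
    ignore_idx: int = -100,
) -> Union[float, List[float]]:
    """Hypothetical reference: per-token CE over flattened rows, no label smoothing."""
    losses = []
    for row, t in zip(logits, targets):
        if t == ignore_idx:
            losses.append(0.0)  # ASSUMPTION: ignored tokens contribute zero loss
            continue
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(lse - row[t])
    if not reduce_loss:
        return losses
    # ASSUMPTION: average over non-ignored tokens only (guard against all-ignored input)
    n_valid = sum(t != ignore_idx for t in targets)
    return sum(losses) / max(n_valid, 1)


rows = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
tg = [0, -100, 1]
print(reference_loss(rows, tg, reduce_loss=False))
print(reference_loss(rows, tg, reduce_loss=True))
```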
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Cross entropy loss: a scalar if reduce_loss is True, otherwise shape (B, SQ) |
| _input | torch.Tensor | The input tensor, modified in place to hold the gradient of the loss with respect to the logits for use in backward |
Usage Examples
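```python
import torch
from transformer_engine.pytorch.triton.cross_entropy import cross_entropy_forward

# Logits (B, SQ, V) and integer targets (B, SQ); the Triton kernels run on a CUDA device
B, SQ, V = 2, 4, 32000
logits = torch.randn(B, SQ, V, device="cuda")
targets = torch.randint(0, V, (B, SQ), device="cuda")
loss, modified_input = cross_entropy_forward(
    logits,
    targets,
    label_smoothing=0.1,
    reduce_loss=True,
    dist_process_group=None,
    ignore_idx=-100,
)
```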