Implementation:NVIDIA TransformerEngine Triton Cross Entropy
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements efficient cross-entropy loss computation and gradient calculation using OpenAI Triton JIT kernels, with support for tensor parallelism and label smoothing.
Description
triton/cross_entropy.py provides three Triton JIT kernels:
- online_softmax_kernel: Computes, for each row, the numerically stable maximum (m) and the sum of exponentials relative to that maximum (d) in a single pass, using the online softmax algorithm (Milakov & Gimelshein, 2018). It also extracts the logit at the target index (X_y). Supports tensor parallelism, where the vocabulary is sharded across TP ranks.
- cross_entropy_kernel: Uses the aggregated m/d/X_y values (potentially from multiple TP ranks after allreduce) to compute the cross-entropy loss and in-place gradient of the input. Supports label smoothing and ignore indices.
- element_mul_kernel: Multiplies the stored gradient element-wise by the upstream gradient value (grad_output), scaling it for the backward pass.
The online softmax approach avoids materializing a full softmax output, reducing memory usage for large vocabularies.
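As an illustration of the single-pass idea, here is a minimal NumPy sketch, not the Triton kernel itself. The function names `online_softmax_stats` and `cross_entropy_from_stats` and the `block_size` parameter are hypothetical and chosen only for this example. It walks one row block by block, keeps a running maximum `m` and a running sum `d`, and then recovers the loss as `m + log(d) - X_y` without ever storing the softmax output.

```python
import numpy as np

def online_softmax_stats(row, block_size=4):
    """Return (m, d) for one row, visiting each block of the row exactly once."""
    m = -np.inf   # running maximum seen so far
    d = 0.0       # running sum of exp(x - m)
    for start in range(0, row.shape[0], block_size):
        block = row[start:start + block_size]
        m_new = max(m, float(block.max()))
        # Rescale the old sum to the new maximum, then add this block's terms.
        d = d * np.exp(m - m_new) + float(np.exp(block - m_new).sum())
        m = m_new
    return m, d

def cross_entropy_from_stats(m, d, x_y):
    """loss = logsumexp(row) - row[target] = m + log(d) - x_y."""
    return m + np.log(d) - x_y

rng = np.random.default_rng(0)
logits = rng.normal(size=10) * 5.0
target = 3
m, d = online_softmax_stats(logits)
loss = cross_entropy_from_stats(m, d, logits[target])
```

Only two scalars per row are carried between blocks, which is why the approach scales to large vocabularies.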
Usage
Use for computing cross-entropy loss in language model training, especially with tensor parallelism where the vocabulary dimension is sharded.
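The sharded case can be sketched in NumPy as follows. This is a simulation under stated assumptions, not the library's communication code: it assumes each rank owns a contiguous, equally sized slice of the vocabulary (rank `r` holds columns `r * n_cols` to `(r + 1) * n_cols - 1`), and the helpers `local_stats`, `merge_stats`, and `sharded_cross_entropy` are hypothetical names. The merge step stands in for whatever collective the real implementation uses to combine the per-rank statistics.

```python
import numpy as np

def local_stats(shard, rank, label):
    """Per-rank statistics for one row: local max, local sum, local target logit."""
    n_cols = shard.shape[0]
    m = float(shard.max())
    d = float(np.exp(shard - m).sum())
    # Only the rank that owns the label contributes the target logit.
    local = label - rank * n_cols
    x_y = float(shard[local]) if 0 <= local < n_cols else 0.0
    return m, d, x_y

def merge_stats(stats):
    """Combine per-rank (m, d, x_y) triples into global values."""
    m_global = max(m for m, _, _ in stats)
    # Each local sum is rescaled from its own maximum to the global maximum.
    d_global = sum(d * np.exp(m - m_global) for m, d, _ in stats)
    x_y = sum(x for _, _, x in stats)
    return m_global, d_global, x_y

def sharded_cross_entropy(logits, label, world_size):
    """Simulate world_size ranks on one process and return the row's loss."""
    shards = np.split(logits, world_size)
    stats = [local_stats(s, r, label) for r, s in enumerate(shards)]
    m, d, x_y = merge_stats(stats)
    return m + np.log(d) - x_y

rng = np.random.default_rng(1)
full_logits = rng.normal(size=12) * 3.0
sharded_loss = sharded_cross_entropy(full_logits, label=7, world_size=4)
```

Because the merge needs only three scalars per row from each rank, no rank has to see another rank's logits.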
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/common/triton/cross_entropy.py
- Lines: 1-262
Signature
@triton.jit
def online_softmax_kernel(
    X_ptr, X_stride, Y_ptr, Y_stride,
    m_d_X_y_ptr, m_d_X_y_stride,
    rank, n_cols, ignore_idx, n_non_ignore,
    BLOCK_SIZE: tl.constexpr):
    ...

@triton.jit
def cross_entropy_kernel(
    X_ptr, X_stride, Y_ptr, Y_stride,
    loss_ptr, loss_stride, m_d_X_y_ptr, m_d_X_y_stride,
    rank, n_cols, ignore_idx, label_smoothing,
    n_non_ignore, BLOCK_SIZE: tl.constexpr):
    ...

@triton.jit
def element_mul_kernel(
    x_ptr, x_stride, grad_output_ptr, n_cols,
    BLOCK_SIZE: tl.constexpr):
    ...
Import
from transformer_engine.common.triton.cross_entropy import (
    online_softmax_kernel,
    cross_entropy_kernel,
    element_mul_kernel,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| X_ptr | pointer | Yes | Logits tensor pointer |
| Y_ptr | pointer | Yes | Target labels tensor pointer |
| rank | int | Yes | TP rank for vocabulary sharding |
| n_cols | int | Yes | Local vocabulary size (on this TP rank) |
| ignore_idx | int | Yes | Index to ignore for loss calculation |
| label_smoothing | float | No | Label smoothing factor |
Outputs
| Name | Type | Description |
|---|---|---|
| loss_ptr | pointer | Per-sample cross-entropy loss |
| m_d_X_y_ptr | pointer | Online softmax statistics (max, sum, target logit) |
Usage Examples
from transformer_engine.common.triton.cross_entropy import (
    online_softmax_kernel, cross_entropy_kernel
)

# Step 1: Compute online softmax statistics (per TP rank)
online_softmax_kernel[(n_rows,)](
    X, X.stride(0), Y, Y.stride(0),
    m_d_X_y, m_d_X_y.stride(0),
    tp_rank, n_cols, ignore_idx, n_non_ignore,
    BLOCK_SIZE=1024,
)

# Step 2: Allreduce m_d_X_y across TP ranks (if TP > 1)

# Step 3: Compute loss and in-place gradients
cross_entropy_kernel[(n_rows,)](
    X, X.stride(0), Y, Y.stride(0),
    loss, loss.stride(0), m_d_X_y, m_d_X_y.stride(0),
    tp_rank, n_cols, ignore_idx, label_smoothing,
    n_non_ignore, BLOCK_SIZE=1024,
)
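For checking results on small inputs, a CPU reference for the loss and gradient can be useful. The NumPy sketch below is a hypothetical helper (`reference_cross_entropy` is not part of TransformerEngine) and makes several assumptions that the page does not confirm: label smoothing spreads the smoothing mass uniformly over all classes, as in PyTorch's `cross_entropy`; the per-row loss is returned unscaled; and the gradient is that of the mean loss over non-ignored rows. The real kernel's scaling conventions may differ.

```python
import numpy as np

def reference_cross_entropy(X, Y, ignore_idx=-100, label_smoothing=0.0):
    """Return (per-row loss, gradient of the mean loss over non-ignored rows)."""
    n_rows, n_cols = X.shape
    keep = Y != ignore_idx
    n_non_ignore = max(int(keep.sum()), 1)

    # Stable log-softmax built from the same m and d statistics.
    m = X.max(axis=1, keepdims=True)
    d = np.exp(X - m).sum(axis=1, keepdims=True)
    log_probs = X - m - np.log(d)
    probs = np.exp(log_probs)

    # Smoothed target: (1 - eps) on the label, eps spread uniformly (assumption).
    target = np.full((n_rows, n_cols), label_smoothing / n_cols)
    safe_Y = np.where(keep, Y, 0)  # placeholder index for ignored rows
    target[np.arange(n_rows), safe_Y] += 1.0 - label_smoothing

    loss = -(target * log_probs).sum(axis=1)
    grad = (probs - target) / n_non_ignore

    # Ignored rows contribute neither loss nor gradient.
    loss[~keep] = 0.0
    grad[~keep] = 0.0
    return loss, grad

rng = np.random.default_rng(2)
X_ref = rng.normal(size=(3, 5))
Y_ref = np.array([1, -100, 4])
loss_ref, grad_ref = reference_cross_entropy(X_ref, Y_ref, label_smoothing=0.1)
```

Each non-ignored gradient row sums to zero because both the softmax and the smoothed target sum to one, which makes a quick sanity check.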