

Implementation:NVIDIA TransformerEngine Triton Cross Entropy

From Leeroopedia


Sources: TransformerEngine
Domains: Deep_Learning, Optimization
Last Updated: 2026-02-07 14:00 GMT

Overview

Implements efficient cross-entropy loss computation and gradient calculation using OpenAI Triton JIT kernels, with support for tensor parallelism and label smoothing.

Description

triton/cross_entropy.py provides three Triton JIT kernels:

  • online_softmax_kernel: In a single pass over each row, computes the running maximum (m) and the sum of exponentials taken relative to that maximum (d), the two quantities of the numerically stable online softmax algorithm (Milakov & Gimelshein, 2018). It also extracts the logit at the target index (X_y). Supports tensor parallelism, where the vocabulary is sharded across TP ranks.
  • cross_entropy_kernel: Uses the aggregated m/d/X_y values (combined across TP ranks when the vocabulary is sharded) to compute the cross-entropy loss and to write the gradient with respect to the logits in place into the input buffer. Supports label smoothing and an ignore index.
  • element_mul_kernel: Scales the stored gradient element-wise by the scalar upstream gradient read from grad_output_ptr.
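The single-pass statistics that online_softmax_kernel is described as computing can be sketched on the CPU in plain Python. This is an illustration of the online softmax recurrence, not the kernel itself; the function name online_softmax_stats and the block_size parameter (standing in for BLOCK_SIZE) are hypothetical, and the real kernel's handling of ignored targets may differ.

```python
import math


def online_softmax_stats(row, target, block_size=4):
    """Hypothetical CPU sketch of the per-row online softmax statistics.

    Returns (m, d, x_y): the running max m, the sum d of exp(x - m) over the
    row, and the logit at the target index (0.0 if the target is not in row).
    """
    m, d, x_y = -math.inf, 0.0, 0.0
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]  # one BLOCK_SIZE-wide chunk
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this chunk's terms.
        d = d * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in block)
        m = m_new
        if start <= target < start + len(block):
            x_y = row[target]
    return m, d, x_y


row = [1.0, 3.0, -2.0, 0.5, 2.0]
m, d, x_y = online_softmax_stats(row, target=1, block_size=2)
# m + log(d) is the log-sum-exp of the row, obtained without a second pass.
print(m, m + math.log(d), x_y)
```

Because d is rescaled whenever the running maximum grows, no exponential ever sees a large positive argument, and one read of the row suffices.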

The online softmax approach avoids materializing a full softmax output, reducing memory usage for large vocabularies.
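To see why the loss itself needs no softmax array: since log(sum_j exp(x_j)) = m + log(d), the per-row loss is -log softmax(x)[y] = m + log(d) - X_y, i.e. three scalars per row. A minimal sketch, with the hypothetical helper loss_from_stats; for brevity the statistics here are computed in two passes rather than online.

```python
import math


def loss_from_stats(m, d, x_y):
    """Cross-entropy of one row from its (max, sum-of-exp, target-logit) stats."""
    # log(sum_j exp(x_j)) = m + log(d), hence -log p_y = m + log(d) - x_y.
    return m + math.log(d) - x_y


row = [2.0, 1.0, 0.1]
target = 0
m = max(row)
d = sum(math.exp(x - m) for x in row)
loss = loss_from_stats(m, d, row[target])

# Reference: materialize the full softmax and take -log of the target entry.
probs = [math.exp(x) / sum(math.exp(v) for v in row) for x in row]
print(loss, -math.log(probs[target]))
```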

Usage

Use for computing cross-entropy loss in language model training, especially with tensor parallelism where the vocabulary dimension is sharded.
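With the vocabulary sharded, each TP rank can only compute m, d, and X_y over its own columns. Merging them is not a plain sum: the global max is the max of the per-rank maxima, and each per-rank d must be rescaled to that global max before adding. The sketch below assumes contiguous sharding (rank r holds global columns [r * n_cols, (r + 1) * n_cols)) and that ranks not holding the target report X_y = 0.0; both are illustrative assumptions, and local_stats and merge_stats are hypothetical names.

```python
import math


def local_stats(shard, rank, n_cols, target):
    """Per-rank (m, d, x_y) for one row, assuming contiguous vocabulary shards."""
    m = max(shard)
    d = sum(math.exp(x - m) for x in shard)
    local = target - rank * n_cols  # target index relative to this shard
    x_y = shard[local] if 0 <= local < n_cols else 0.0
    return m, d, x_y


def merge_stats(per_rank):
    """Combine per-rank (m, d, x_y) triples into full-vocabulary statistics."""
    m = max(m_r for m_r, _, _ in per_rank)
    # Rescale each rank's sum from its own max m_r to the global max m.
    d = sum(d_r * math.exp(m_r - m) for m_r, d_r, _ in per_rank)
    # Only the rank owning the target contributes a non-zero x_y.
    x_y = sum(x_r for _, _, x_r in per_rank)
    return m, d, x_y


full_row = [0.3, 2.5, -1.0, 1.7, 0.0, 3.1]
n_cols, target = 3, 3  # 2 ranks x 3 columns; target lives on rank 1
shards = [full_row[r * n_cols:(r + 1) * n_cols] for r in range(2)]
m, d, x_y = merge_stats([local_stats(s, r, n_cols, target) for r, s in enumerate(shards)])
print(m, d, x_y)
```

The merged triple matches what a single rank holding the full row would compute, so the loss formula m + log(d) - X_y applies unchanged after the exchange.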

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/common/triton/cross_entropy.py
Lines: 1-262

Signature

```python
import triton
import triton.language as tl


@triton.jit
def online_softmax_kernel(
    X_ptr, X_stride, Y_ptr, Y_stride,
    m_d_X_y_ptr, m_d_X_y_stride,
    rank, n_cols, ignore_idx, n_non_ignore,
    BLOCK_SIZE: tl.constexpr):
    ...  # body omitted; see the source file

@triton.jit
def cross_entropy_kernel(
    X_ptr, X_stride, Y_ptr, Y_stride,
    loss_ptr, loss_stride, m_d_X_y_ptr, m_d_X_y_stride,
    rank, n_cols, ignore_idx, label_smoothing,
    n_non_ignore, BLOCK_SIZE: tl.constexpr):
    ...  # body omitted; see the source file

@triton.jit
def element_mul_kernel(
    x_ptr, x_stride, grad_output_ptr, n_cols,
    BLOCK_SIZE: tl.constexpr):
    ...  # body omitted; see the source file
```

Import

```python
from transformer_engine.common.triton.cross_entropy import (
    online_softmax_kernel,
    cross_entropy_kernel,
    element_mul_kernel,
)
```

I/O Contract

Inputs

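| Name | Type | Required | Description |
|---|---|---|---|
| X_ptr | pointer | Yes | Pointer to the logits tensor (this rank's vocabulary shard) |
| Y_ptr | pointer | Yes | Pointer to the target labels tensor |
| rank | int | Yes | Tensor-parallel (TP) rank of this process; determines which vocabulary slice it holds |
| n_cols | int | Yes | Local vocabulary size (the number of columns on this TP rank) |
| ignore_idx | int | Yes | Target label value whose rows are excluded from the loss |
| label_smoothing | float | No | Label smoothing factor (0.0 disables smoothing) |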
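The label_smoothing and ignore_idx inputs change what each row contributes. The sketch below uses the common formulation in which smoothing mixes the one-hot target with a uniform distribution, q = eps / V + (1 - eps) * onehot(y), so the gradient with respect to the logits is softmax(x) - q, and ignored rows contribute zero loss and zero gradient. This is an assumption about the semantics, not a transcription of cross_entropy_kernel; the function name smoothed_ce_row and the default ignore_idx of -100 (PyTorch's convention) are illustrative, and only a single rank (V = len(row)) is modeled.

```python
import math


def smoothed_ce_row(row, target, ignore_idx=-100, label_smoothing=0.0):
    """Return (loss, grad) for one row of logits; hypothetical reference."""
    n = len(row)
    if target == ignore_idx:
        # Ignored rows contribute neither loss nor gradient.
        return 0.0, [0.0] * n
    m = max(row)
    d = sum(math.exp(x - m) for x in row)
    lse = m + math.log(d)  # log-sum-exp of the row
    eps = label_smoothing
    # -sum_i q_i log p_i with q_i = eps/n + (1 - eps) * [i == target]
    loss = (1.0 - eps) * (lse - row[target]) + (eps / n) * sum(lse - x for x in row)
    # dLoss/dx_i = p_i - q_i
    grad = [math.exp(x - m) / d - eps / n for x in row]
    grad[target] -= 1.0 - eps
    return loss, grad


row = [2.0, -1.0, 0.5, 1.5]
loss, grad = smoothed_ce_row(row, target=0, label_smoothing=0.1)
print(loss, grad)
```

Because both p and q sum to 1, each row's gradient sums to zero, which is a cheap sanity check for any implementation of this formulation.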

Outputs

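| Name | Type | Description |
|---|---|---|
| loss_ptr | pointer | Per-row (per-token) cross-entropy loss |
| m_d_X_y_ptr | pointer | Online softmax statistics per row: max (m), sum of exponentials (d), target logit (X_y) |
| X_ptr (in place) | pointer | Gradient of the loss with respect to the logits, written over the logits by cross_entropy_kernel |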

Usage Examples

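```python
from transformer_engine.common.triton.cross_entropy import (
    online_softmax_kernel, cross_entropy_kernel
)

# Assumed inputs (not defined on this page):
#   X        : logits for this TP rank, shape (n_rows, n_cols); overwritten with gradients in Step 3
#   Y        : target labels, shape (n_rows,)
#   loss     : per-row loss buffer, shape (n_rows,)
#   m_d_X_y  : buffer for the per-row (max, sum, target logit) statistics; its exact
#              layout is implementation-specific, see the source file
# Each launch uses a 1-D grid with one program per row.

# Step 1: Compute online softmax statistics (per TP rank)
online_softmax_kernel[(n_rows,)](
    X, X.stride(0), Y, Y.stride(0),
    m_d_X_y, m_d_X_y.stride(0),
    tp_rank, n_cols, ignore_idx, n_non_ignore,
    BLOCK_SIZE=1024,
)

# Step 2 (only if TP > 1): communicate m_d_X_y across TP ranks so that each rank
# can combine the statistics for the full vocabulary

# Step 3: Compute loss and in-place gradients
cross_entropy_kernel[(n_rows,)](
    X, X.stride(0), Y, Y.stride(0),
    loss, loss.stride(0), m_d_X_y, m_d_X_y.stride(0),
    tp_rank, n_cols, ignore_idx, label_smoothing,
    n_non_ignore, BLOCK_SIZE=1024,
)
```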
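After the forward steps above, X holds the gradient of the loss with respect to the logits. In the backward pass, the chain rule requires multiplying that stored gradient by the upstream gradient of the loss, which is the role element_mul_kernel plays per the Description. A pure-Python analogue, assuming a scalar upstream gradient; scale_grad_rows is a hypothetical name.

```python
def scale_grad_rows(grad_rows, grad_output):
    """Multiply every stored gradient element by the scalar grad_output, in place."""
    if grad_output == 1.0:
        return grad_rows  # scaling by 1 is a no-op, so the work can be skipped
    for row in grad_rows:
        for j in range(len(row)):
            row[j] *= grad_output
    return grad_rows


grads = [[0.5, -0.5], [0.25, -0.25]]
scale_grad_rows(grads, 2.0)
print(grads)  # [[1.0, -1.0], [0.5, -0.5]]
```

Scaling in place reuses the logits buffer, in keeping with the in-place gradient design; the early return for grad_output == 1.0 is a choice made in this sketch, not a documented behavior of the kernel.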
