Implementation:Deepspeedai DeepSpeed Evoformer LogSumExp

From Leeroopedia


Knowledge Sources
Domains: Attention, CUTLASS_Kernels, DeepSpeed4Science
Last Updated: 2026-02-09 00:00 GMT

Overview

An epilogue thread operator that applies the transformation exp(x - logsumexp), converting attention logits to probabilities in a numerically stable way.

Description

ApplyLogSumExp implements the numerically stable softmax step output = exp(logits - lse) used in attention mechanisms. Computing raw exp(logits) can overflow, so the operator first subtracts the logsumexp statistic (the log of the sum of exponentials over a row) and only then exponentiates. Provided the statistic was computed from the same logits, it is at least as large as every logit in its row, so the exponent is never positive and each result lies in (0, 1]. The template operates on vectorized fragments: an element-wise subtraction followed by a vectorized exponential. For half-precision (FP16) types, a specialized ArrayExponential uses the CUDA h2exp intrinsic on half2 pairs for efficiency. The operator plugs into CUTLASS's epilogue infrastructure: it receives the logsumexp values through the bias parameter channel and produces normalized probability outputs.
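The overflow problem and its fix can be illustrated with a minimal host-side sketch in plain C++ (no CUTLASS or CUDA). The helper names `log_sum_exp` and `stable_prob` are hypothetical and do not appear in DeepSpeed; the max-shift used inside `log_sum_exp` is the usual way to compute the statistic safely and is an assumption about how the forward pass produces it.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper: logsumexp of one row of logits.
// The row maximum is subtracted first so the reduction itself cannot overflow.
inline float log_sum_exp(const std::vector<float>& row) {
    assert(!row.empty());
    const float m = *std::max_element(row.begin(), row.end());
    float sum = 0.0f;
    for (float x : row) sum += std::exp(x - m);
    return m + std::log(sum);
}

// Hypothetical helper mirroring the operator's element-wise math:
// probability = exp(logit - lse). The exponent is <= 0, so no overflow.
inline float stable_prob(float logit, float lse) {
    return std::exp(logit - lse);
}
```

With logits around 100, `std::exp` on a single-precision logit already overflows to infinity, while `stable_prob` returns finite values in (0, 1] that sum to 1 across the row.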

Usage

This operator is used in the attention backward pass when recomputing attention probabilities from saved logits and logsumexp statistics, converting stored Q×K^T scores back to the softmax-normalized attention weights P = softmax(Q×K^T).

Code Reference

Source Location

Signature

// Generic vectorized exponential (used for float and other non-half types)
template <typename Element, int ElementsPerAccess>
struct ArrayExponential {
    CUTLASS_HOST_DEVICE
    Array<Element, ElementsPerAccess> operator()(
        Array<Element, ElementsPerAccess> const& input) const;
};

// Specialized vectorized exponential for half_t using h2exp
template <int ElementsPerAccess>
struct ArrayExponential<half_t, ElementsPerAccess> {
    CUTLASS_DEVICE
    Array<half_t, ElementsPerAccess> operator()(
        Array<half_t, ElementsPerAccess> const& input) const;
};

// Main epilogue operator
template <typename ElementOutput_,      // Output element type
          typename ElementLSE_,         // LSE element type
          typename ElementAccumulator_, // Accumulator type
          typename ElementCompute_,     // Compute type
          int ElementsPerAccess>
class ApplyLogSumExp {
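public:
    // Aliases for the underscore-suffixed template parameters
    using ElementOutput = ElementOutput_;
    using ElementLSE = ElementLSE_;
    using ElementAccumulator = ElementAccumulator_;
    using ElementCompute = ElementCompute_;

    using FragmentOutput = Array<ElementOutput, ElementsPerAccess>;
    using FragmentAccumulator = Array<ElementAccumulator, ElementsPerAccess>;
    using FragmentLSE = Array<ElementLSE, ElementsPerAccess>;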

    CUTLASS_HOST_DEVICE FragmentOutput operator()(
        FragmentAccumulator const& AB,      // Attention logits
        FragmentLSE const& scale_unused,    // Unused scale
        FragmentLSE const& bias             // LSE values (passed as bias)
    ) const;
};
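To make the call contract of the signature concrete, here is a host-only stand-in that can be compiled without CUTLASS. It is a sketch under stated assumptions, not the DeepSpeed implementation: `ApplyLogSumExpSketch` is a hypothetical name, `std::array` plays the role of `cutlass::Array`, and the exponential is a plain scalar loop rather than `ArrayExponential`.

```cpp
#include <array>
#include <cmath>
#include <cstddef>

// Hypothetical host-side mock of the operator's interface.
template <typename ElementOutput, typename ElementLSE,
          typename ElementAccumulator, typename ElementCompute, std::size_t N>
struct ApplyLogSumExpSketch {
    using FragmentOutput = std::array<ElementOutput, N>;
    using FragmentAccumulator = std::array<ElementAccumulator, N>;
    using FragmentLSE = std::array<ElementLSE, N>;

    FragmentOutput operator()(const FragmentAccumulator& AB,
                              const FragmentLSE& /*scale_unused*/,
                              const FragmentLSE& bias) const {
        FragmentOutput out{};
        for (std::size_t i = 0; i < N; ++i) {
            // Convert both operands to the compute type, then subtract the LSE.
            const ElementCompute shifted = static_cast<ElementCompute>(AB[i]) -
                                           static_cast<ElementCompute>(bias[i]);
            // Exponentiate in the compute type and narrow to the output type.
            out[i] = static_cast<ElementOutput>(std::exp(shifted));
        }
        return out;
    }
};
```

The second argument is accepted only to match the three-argument call shape and is never read, which mirrors the `scale_unused` parameter in the signature above.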

Import

#include "csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h"

I/O Contract

Parameter | Type | Description
Inputs
AB | FragmentAccumulator | Attention logits from Q×K^T GEMM
bias (LSE) | FragmentLSE | Logsumexp values per row
scale_unused | FragmentLSE | Unused scale parameter
Output
FragmentOutput | Array<Element, N> | Softmax probabilities exp(AB - LSE)
Operations
Subtract | AB - LSE | Numerically stable logit adjustment
Exp | exp(adjusted) | Exponential to get probabilities
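The two operations reproduce the softmax exactly, which a short derivation makes explicit. The symbols below are notation chosen for this note: S is the score matrix Q×K^T, LSE_i is the logsumexp of row i, and P is the resulting probability matrix.

```latex
\begin{aligned}
\mathrm{LSE}_i &= \log \sum_{k} \exp(S_{ik}), \\
P_{ij} &= \exp\bigl(S_{ij} - \mathrm{LSE}_i\bigr)
        = \frac{\exp(S_{ij})}{\exp(\mathrm{LSE}_i)}
        = \frac{\exp(S_{ij})}{\sum_{k} \exp(S_{ik})}
        = \operatorname{softmax}(S_i)_j .
\end{aligned}
```

Since the sum contains the term for k = j, LSE_i is at least S_ij for every j, so the exponent is never positive and 0 < P_ij <= 1.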

Usage Examples

// Configure epilogue operator for FP16 attention backward
using LogSumExpOp = cutlass::epilogue::thread::ApplyLogSumExp<
    cutlass::half_t,   // Output probabilities
    float,             // LSE stored in FP32
    float,             // Accumulator (logits) in FP32
    float,             // Compute in FP32
    8                  // Process 8 elements at once
>;

// Use in epilogue when recomputing attention weights
// During backward: need P = softmax(S) where S = Q×K^T
// Forward saved: logsumexp[i] = log(sum_j exp(S[i,j]))
// Backward computes: P[i,j] = exp(S[i,j] - logsumexp[i])

using Epilogue = cutlass::epilogue::threadblock::EpiloguePipelined<
    Shape, WarpMma, 1,
    OutputIterator,
    AccumIterator,
    WarpTileIterator,
    SharedLoadIterator,
    LogSumExpOp,       // Apply log-sum-exp transformation
    Padding
>;

// Execute: reads S from GEMM, reads LSE from memory, outputs P
epilogue(output_op, destination, accumulators, logsumexp_source);
