Implementation:Deepspeedai DeepSpeed Evoformer LogSumExp
| Knowledge Sources | |
|---|---|
| Domains | Attention, CUTLASS_Kernels, DeepSpeed4Science |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
An epilogue thread operator that applies the exp(x - logsumexp) transformation to convert attention logits into probabilities in a numerically stable way.
Description
ApplyLogSumExp implements the numerically stable softmax computation output = exp(logits - lse) used in attention. Computing exp(logits) directly can overflow, so the operator first subtracts the logsumexp statistic (the log of the row's sum of exponentials) and only then exponentiates. Because lse is at least as large as every logit in its row, each exponent is at most 0 and every output lies in (0, 1]. The template works on vectorized fragments. It converts the accumulator logits and the LSE values to the compute type, subtracts them element-wise, applies a vectorized exponential, and converts the result to the output type. When the compute type is half precision (cutlass::half_t), a specialized ArrayExponential processes the elements as half2 pairs using the GPU's native h2exp intrinsic. The operator plugs into CUTLASS's epilogue infrastructure. It receives the logsumexp values through the bias argument and produces normalized probabilities.
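The overflow argument can be checked on the host with plain C++. The sketch below is not DeepSpeed code, and the function names are hypothetical. `row_logsumexp` assumes the forward pass computes LSE with the usual max-shift trick. `lse_probs` mirrors the per-element exp(x - lse) step described above.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical forward-pass statistic: lse = m + log(sum_j exp(x_j - m)),
// where m = max_j x_j keeps every exponent <= 0.
float row_logsumexp(const std::vector<float>& row) {
  assert(!row.empty());
  float m = *std::max_element(row.begin(), row.end());
  float sum = 0.0f;
  for (float x : row) sum += std::exp(x - m);
  return m + std::log(sum);
}

// Naive softmax: exp(x) / sum(exp(x)). Overflows to inf for large logits,
// and inf / inf then yields NaN.
std::vector<float> naive_probs(const std::vector<float>& row) {
  float sum = 0.0f;
  for (float x : row) sum += std::exp(x);
  std::vector<float> p;
  p.reserve(row.size());
  for (float x : row) p.push_back(std::exp(x) / sum);
  return p;
}

// The per-element transformation ApplyLogSumExp performs: exp(x - lse).
std::vector<float> lse_probs(const std::vector<float>& row, float lse) {
  std::vector<float> p;
  p.reserve(row.size());
  for (float x : row) p.push_back(std::exp(x - lse));
  return p;
}
```

With logits such as {1000, 1001, 1002}, `naive_probs` returns NaN. `lse_probs(logits, row_logsumexp(logits))` returns finite values in (0, 1] that sum to 1 up to float rounding.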
Usage
This operator is used in the attention backward pass to recompute attention probabilities instead of storing them. A GEMM recomputes the Q×K^T logits, and the logsumexp statistics saved by the forward pass convert them back into the softmax-normalized attention weights P = softmax(Q×K^T).
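The identity that makes this recomputation possible is shown below. It uses the row index i and column index j from the comments in the usage example:

$$
P_{ij} = \frac{\exp(S_{ij})}{\sum_{k} \exp(S_{ik})}
       = \exp\!\Big(S_{ij} - \log \sum_{k} \exp(S_{ik})\Big)
       = \exp(S_{ij} - \mathrm{lse}_i),
\qquad \mathrm{lse}_i = \log \sum_{k} \exp(S_{ik})
$$

$$
\mathrm{lse}_i \ge \max_k S_{ik} \;\Rightarrow\; S_{ij} - \mathrm{lse}_i \le 0 \;\Rightarrow\; 0 < P_{ij} \le 1,
\qquad \sum_j P_{ij} = e^{-\mathrm{lse}_i} \sum_j e^{S_{ij}} = 1
$$

A single saved scalar per row is therefore enough to rebuild the full row of probabilities.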
Code Reference
Source Location
- Repository: DeepSpeed
- File: csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h
Signature
```cpp
// Generic vectorized exponential: applies exp element-wise
template <typename Element, int ElementsPerAccess>
struct ArrayExponential {
  CUTLASS_HOST_DEVICE
  Array<Element, ElementsPerAccess> operator()(
      Array<Element, ElementsPerAccess> const& input) const;
};

// Specialization for half_t: processes elements as half2 pairs with h2exp
// (device-only; ElementsPerAccess is expected to be even)
template <int ElementsPerAccess>
struct ArrayExponential<half_t, ElementsPerAccess> {
  CUTLASS_DEVICE
  Array<half_t, ElementsPerAccess> operator()(
      Array<half_t, ElementsPerAccess> const& input) const;
};

// Main epilogue operator (used as cutlass::epilogue::thread::ApplyLogSumExp)
template <typename ElementOutput_,       // Output element type
          typename ElementLSE_,          // LSE element type
          typename ElementAccumulator_,  // Accumulator type
          typename ElementCompute_,      // Compute type
          int ElementsPerAccess>
class ApplyLogSumExp {
 public:
  using ElementOutput = ElementOutput_;
  using ElementLSE = ElementLSE_;
  using ElementAccumulator = ElementAccumulator_;
  using ElementCompute = ElementCompute_;

  using FragmentOutput = Array<ElementOutput, ElementsPerAccess>;
  using FragmentAccumulator = Array<ElementAccumulator, ElementsPerAccess>;
  using FragmentLSE = Array<ElementLSE, ElementsPerAccess>;

  // Returns exp(AB - bias), computed in ElementCompute, converted to ElementOutput
  CUTLASS_HOST_DEVICE FragmentOutput operator()(
      FragmentAccumulator const& AB,    // Attention logits
      FragmentLSE const& scale_unused,  // Ignored; matches the epilogue interface
      FragmentLSE const& bias           // LSE values (passed as bias)
  ) const;
};
```
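The data flow inside `operator()` can be sketched on the host. The example below assumes that `std::array` can stand in for `cutlass::Array` and that `static_cast` can stand in for CUTLASS's numeric conversion. `ApplyLogSumExpMock`, `Fragment`, and `convert` are hypothetical names. The sketch only illustrates the argument order and the convert, subtract, exp, convert sequence. It is not the header's implementation.

```cpp
#include <array>
#include <cassert>
#include <cmath>

// Hypothetical host stand-in for cutlass::Array<T, N>.
template <typename T, int N>
using Fragment = std::array<T, N>;

// Element-wise type conversion (stand-in for CUTLASS's array converter).
template <typename To, typename From, int N>
Fragment<To, N> convert(const Fragment<From, N>& in) {
  Fragment<To, N> out{};
  for (int i = 0; i < N; ++i) out[i] = static_cast<To>(in[i]);
  return out;
}

// Hypothetical mock with the same template parameters and argument order
// as ApplyLogSumExp, using plain C++ arithmetic instead of CUTLASS types.
template <typename ElementOutput, typename ElementLSE,
          typename ElementAccumulator, typename ElementCompute,
          int ElementsPerAccess>
class ApplyLogSumExpMock {
 public:
  static_assert(ElementsPerAccess > 0, "fragment must be non-empty");
  using FragmentOutput = Fragment<ElementOutput, ElementsPerAccess>;
  using FragmentAccumulator = Fragment<ElementAccumulator, ElementsPerAccess>;
  using FragmentLSE = Fragment<ElementLSE, ElementsPerAccess>;
  using FragmentCompute = Fragment<ElementCompute, ElementsPerAccess>;

  FragmentOutput operator()(const FragmentAccumulator& AB,
                            const FragmentLSE& /*scale_unused*/,
                            const FragmentLSE& bias) const {
    // 1. Bring logits and LSE values into the compute type.
    FragmentCompute ab =
        convert<ElementCompute, ElementAccumulator, ElementsPerAccess>(AB);
    FragmentCompute lse =
        convert<ElementCompute, ElementLSE, ElementsPerAccess>(bias);
    // 2. Subtract the LSE, then exponentiate element-wise.
    FragmentCompute p{};
    for (int i = 0; i < ElementsPerAccess; ++i) p[i] = std::exp(ab[i] - lse[i]);
    // 3. Convert the probabilities to the output type.
    return convert<ElementOutput, ElementCompute, ElementsPerAccess>(p);
  }
};
```

The scale argument never affects the result. Only `bias` carries information, which is why the LSE values travel through the bias channel.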
Import
```cpp
#include "csrc/deepspeed4science/evoformer_attn/epilogue/epilogue_thread_apply_logsumexp.h"
```
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| Inputs | | |
| AB | FragmentAccumulator | Attention logits from the Q×K^T GEMM |
| scale_unused | FragmentLSE | Unused scale argument |
| bias (LSE) | FragmentLSE | Logsumexp value of each element's row |
| Output | | |
| Return value | FragmentOutput (Array<ElementOutput, ElementsPerAccess>) | Softmax probabilities exp(AB - LSE) |
| Operations | | |
| Convert | AB, LSE → ElementCompute | Bring the inputs into the compute type |
| Subtract | AB - LSE | Numerically stable logit adjustment |
| Exp | exp(adjusted) | Exponential to get probabilities |
| Convert | result → ElementOutput | Store in the output type |
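The following worked example applies the contract to a single row of three logits. The values are illustrative and do not come from DeepSpeed. Take S = (1, 2, 3):

$$
\mathrm{lse} = \log\!\left(e^{1} + e^{2} + e^{3}\right) \approx \log(30.1929) \approx 3.4076
$$

$$
P \approx \left(e^{1 - 3.4076},\; e^{2 - 3.4076},\; e^{3 - 3.4076}\right) \approx (0.0900,\; 0.2447,\; 0.6652)
$$

All three exponents are negative, so no intermediate value exceeds 1. The outputs sum to 1 up to rounding.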
Usage Examples
```cpp
// Configure the epilogue operator for FP16 attention backward
using LogSumExpOp = cutlass::epilogue::thread::ApplyLogSumExp<
    cutlass::half_t,  // Output probabilities in FP16
    float,            // LSE stored in FP32
    float,            // Accumulator (logits) in FP32
    float,            // Compute in FP32 (generic exp path; h2exp applies only to a half_t compute type)
    8                 // Elements processed per access
>;

// Use in the epilogue when recomputing attention weights.
// During backward: need P = softmax(S) where S = Q×K^T
// Forward saved:     logsumexp[i] = log(sum_j exp(S[i,j]))
// Backward computes: P[i,j] = exp(S[i,j] - logsumexp[i])
//
// Schematic only: Shape, WarpMma, the iterators and Padding are placeholders
// defined by the enclosing kernel.
using Epilogue = cutlass::epilogue::threadblock::EpiloguePipelined<
    Shape, WarpMma, 1,  // 1 = PartitionsK
    OutputIterator,
    AccumIterator,
    WarpTileIterator,
    SharedLoadIterator,
    LogSumExpOp,  // Apply the log-sum-exp transformation
    Padding
>;

// Execute: reads S from the GEMM accumulators, reads LSE from memory, writes P.
// epilogue, output_op, destination, accumulators and logsumexp_source are
// assumed to be constructed by the kernel.
epilogue(output_op, destination, accumulators, logsumexp_source);
```
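The epilogue's job can be sketched with a plain host loop. It walks a tile of S one fragment at a time, fills the fragment's bias with the LSE of its row, and writes exp(S - lse). The sketch below is a hypothetical C++ illustration rather than the CUTLASS iterator machinery. `recompute_probs` and `row_lse` are made-up names. The code assumes a row-major tile whose column count is a multiple of the fragment width.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical forward-pass statistic per row (max-shifted for stability).
std::vector<float> row_lse(const std::vector<float>& S, std::size_t rows,
                           std::size_t cols) {
  std::vector<float> lse(rows);
  for (std::size_t r = 0; r < rows; ++r) {
    float m = S[r * cols];
    for (std::size_t c = 1; c < cols; ++c) m = std::fmax(m, S[r * cols + c]);
    float sum = 0.0f;
    for (std::size_t c = 0; c < cols; ++c) sum += std::exp(S[r * cols + c] - m);
    lse[r] = m + std::log(sum);
  }
  return lse;
}

// Hypothetical backward recomputation: process the tile in fragments of
// kElementsPerAccess elements, broadcasting the row's LSE into "bias".
template <int kElementsPerAccess>
std::vector<float> recompute_probs(const std::vector<float>& S,
                                   const std::vector<float>& lse,
                                   std::size_t rows, std::size_t cols) {
  assert(S.size() == rows * cols && lse.size() == rows);
  assert(cols % kElementsPerAccess == 0);
  std::vector<float> P(S.size());
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; c += kElementsPerAccess) {
      std::array<float, kElementsPerAccess> ab{}, bias{};
      for (int i = 0; i < kElementsPerAccess; ++i) {
        ab[i] = S[r * cols + c + i];
        bias[i] = lse[r];  // every element of row r shares the same LSE
      }
      for (int i = 0; i < kElementsPerAccess; ++i)
        P[r * cols + c + i] = std::exp(ab[i] - bias[i]);
    }
  }
  return P;
}
```

For example, `recompute_probs<2>(S, row_lse(S, 2, 4), 2, 4)` on a 2×4 tile yields rows that each sum to 1. A row of four equal logits maps to 0.25 per element.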